-
Notifications
You must be signed in to change notification settings - Fork 2k
[None] [feat] Support multiple accuracy tasks for slurm scripts #10500
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -355,12 +355,33 @@ def submit_job(config, log_dir, dry_run): | |
| ] | ||
| client_cmds.append(" ".join(client_slurm_prefix + benchmark_cmd)) | ||
| if config['accuracy']['enable_accuracy_test']: | ||
| accuracy_cmd = [ | ||
| f"bash {env_config['work_dir']}/accuracy_eval.sh", | ||
| f"'{log_dir}' '{config['accuracy']['model']}' '{config['accuracy']['tasks']}' '{env_config['model_path']}' '{config['accuracy']['model_args_extra']}' '{log_dir}/accuracy_eval' {disagg_server_hostname} {disagg_server_port}", | ||
| f"&> {log_dir}/7_accuracy_eval.log" | ||
| ] | ||
| client_cmds.append(" ".join(client_slurm_prefix + accuracy_cmd)) | ||
| install_dep_cmd = "pip3 install lm_eval[api]==0.4.9.2" | ||
| client_cmds.append(" ".join(client_slurm_prefix) + " " + install_dep_cmd) | ||
| for task in config['accuracy']['tasks']: | ||
| extra_kwargs = config['accuracy']['tasks'][task].get('extra_kwargs', {}) | ||
| extra_kwargs_str = "" | ||
| for key, value in extra_kwargs.items(): | ||
| if isinstance(value, bool): | ||
| if value: | ||
| extra_kwargs_str += f" --{key}" | ||
| else: | ||
| extra_kwargs_str += f" --{key}='{value}'" | ||
| end_point_map = { | ||
| 'local-completions': 'v1/completions', | ||
| 'local-chat-completions': 'v1/chat/completions', | ||
| } | ||
| model = config['accuracy']['tasks'][task]['model'] | ||
| accuracy_cmd = [ | ||
| 'lm_eval', | ||
| '--model', model, | ||
| '--tasks', task, | ||
| '--model_args', f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{config['accuracy']['tasks'][task]['model_args_extra']}", | ||
| '--log_samples', | ||
| '--output_path', f'{log_dir}/accuracy_eval_{task}', | ||
| extra_kwargs_str, | ||
| f"&> {log_dir}/7_accuracy_eval_{task}.log" | ||
| ] | ||
| client_cmds.append(" ".join(client_slurm_prefix + accuracy_cmd)) | ||
|
Comment on lines
+358
to
+384
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical: Type mismatch between default config and new task iteration logic. The default accuracy config at lines 152-163 sets Additionally:
Proposed fixFirst, update the default config structure to match the new expected format (lines 152-163): if 'accuracy' not in config:
config['accuracy'] = {
- 'enable_accuracy_test':
- False,
- 'model':
- 'local-completions',
- 'tasks':
- 'gsm8k',
- 'model_args_extra':
- 'num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096'
+ 'enable_accuracy_test': False,
+ 'tasks': {
+ 'gsm8k': {
+ 'model': 'local-completions',
+ 'model_args_extra': 'num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096'
+ }
+ }
}Then, add validation and safe access for the task loop: if config['accuracy']['enable_accuracy_test']:
install_dep_cmd = "pip3 install lm_eval[api]==0.4.9.2"
client_cmds.append(" ".join(client_slurm_prefix) + " " + install_dep_cmd)
+ supported_models = {'local-completions', 'local-chat-completions'}
for task in config['accuracy']['tasks']:
extra_kwargs = config['accuracy']['tasks'][task].get('extra_kwargs', {})
extra_kwargs_str = ""
for key, value in extra_kwargs.items():
if isinstance(value, bool):
if value:
extra_kwargs_str += f" --{key}"
else:
extra_kwargs_str += f" --{key}='{value}'"
end_point_map = {
'local-completions': 'v1/completions',
'local-chat-completions': 'v1/chat/completions',
}
- model = config['accuracy']['tasks'][task]['model']
+ model = config['accuracy']['tasks'][task].get('model', 'local-completions')
+ if model not in supported_models:
+ raise ValueError(f"Unsupported model type '{model}' for task '{task}'. Supported: {supported_models}")
+ model_args_extra = config['accuracy']['tasks'][task].get('model_args_extra', '')
accuracy_cmd = [
'lm_eval',
'--model', model,
'--tasks', task,
- '--model_args', f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{config['accuracy']['tasks'][task]['model_args_extra']}",
+ '--model_args', f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{model_args_extra}",
'--log_samples',
'--output_path', f'{log_dir}/accuracy_eval_{task}',
extra_kwargs_str,
f"&> {log_dir}/7_accuracy_eval_{task}.log"
]
client_cmds.append(" ".join(client_slurm_prefix + accuracy_cmd))
🤖 Prompt for AI Agents |
||
| with open(os.path.join(log_dir, "client_cmds.sh"), "w") as f: | ||
| f.write("\n".join(client_cmds) + "\n") | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trailing comma when
model_args_extrais empty.If
model_args_extrais an empty string, the--model_argsvalue will end with a trailing comma (e.g.,model=...,base_url=...,), which may cause parsing issues inlm_eval.Suggested fix
🤖 Prompt for AI Agents