34 changes: 0 additions & 34 deletions examples/disaggregated/slurm/benchmark/accuracy_eval.sh

This file was deleted.

33 changes: 27 additions & 6 deletions examples/disaggregated/slurm/benchmark/submit.py
@@ -355,12 +355,33 @@ def submit_job(config, log_dir, dry_run):
         ]
         client_cmds.append(" ".join(client_slurm_prefix + benchmark_cmd))
     if config['accuracy']['enable_accuracy_test']:
-        accuracy_cmd = [
-            f"bash {env_config['work_dir']}/accuracy_eval.sh",
-            f"'{log_dir}' '{config['accuracy']['model']}' '{config['accuracy']['tasks']}' '{env_config['model_path']}' '{config['accuracy']['model_args_extra']}' '{log_dir}/accuracy_eval' {disagg_server_hostname} {disagg_server_port}",
-            f"&> {log_dir}/7_accuracy_eval.log"
-        ]
-        client_cmds.append(" ".join(client_slurm_prefix + accuracy_cmd))
+        install_dep_cmd = "pip3 install lm_eval[api]==0.4.9.2"
+        client_cmds.append(" ".join(client_slurm_prefix) + " " + install_dep_cmd)
+        for task in config['accuracy']['tasks']:
+            extra_kwargs = config['accuracy']['tasks'][task].get('extra_kwargs', {})
+            extra_kwargs_str = ""
+            for key, value in extra_kwargs.items():
+                if isinstance(value, bool):
+                    if value:
+                        extra_kwargs_str += f" --{key}"
+                else:
+                    extra_kwargs_str += f" --{key}='{value}'"
+            end_point_map = {
+                'local-completions': 'v1/completions',
+                'local-chat-completions': 'v1/chat/completions',
+            }
+            model = config['accuracy']['tasks'][task]['model']
+            accuracy_cmd = [
+                'lm_eval',
+                '--model', model,
+                '--tasks', task,
+                '--model_args', f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{config['accuracy']['tasks'][task]['model_args_extra']}",
⚠️ Potential issue | 🟡 Minor

Trailing comma when model_args_extra is empty.

If model_args_extra is an empty string, the --model_args value will end with a trailing comma (e.g., model=...,base_url=...,), which may cause parsing issues in lm_eval.

Suggested fix
+            model_args_extra = config['accuracy']['tasks'][task].get('model_args_extra', '')
+            model_args_parts = [
+                f"model={env_config['model_path']}",
+                f"base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]}"
+            ]
+            if model_args_extra:
+                model_args_parts.append(model_args_extra)
             accuracy_cmd = [
                 'lm_eval',
                 '--model', model,
                 '--tasks', task,
-                '--model_args', f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{config['accuracy']['tasks'][task]['model_args_extra']}",
+                '--model_args', ','.join(model_args_parts),

Committable suggestion skipped: line range outside the PR's diff.
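
The list-join fix suggested above can be sketched as a standalone helper (the name `build_model_args` and the sample values are illustrative, not from the PR):

```python
def build_model_args(model_path, base_url, model_args_extra=""):
    """Join --model_args parts so an empty extras string never leaves a trailing comma."""
    parts = [f"model={model_path}", f"base_url={base_url}"]
    if model_args_extra:  # append the extras only when non-empty
        parts.append(model_args_extra)
    return ",".join(parts)

# Extras present: appended after a comma.
print(build_model_args("/models/llama", "http://host:8000/v1/completions",
                       "num_concurrent=512,max_retries=3"))
# Extras empty: the string ends cleanly at base_url, with no dangling comma.
print(build_model_args("/models/llama", "http://host:8000/v1/completions"))
```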

+                '--log_samples',
+                '--output_path', f'{log_dir}/accuracy_eval_{task}',
+                extra_kwargs_str,
+                f"&> {log_dir}/7_accuracy_eval_{task}.log"
+            ]
+            client_cmds.append(" ".join(client_slurm_prefix + accuracy_cmd))
Comment on lines +358 to +384
⚠️ Potential issue | 🔴 Critical

Critical: Type mismatch between default config and new task iteration logic.

The default accuracy config at lines 152-163 sets 'tasks': 'gsm8k' as a string, but the new code iterates over config['accuracy']['tasks'] as a dictionary (line 360). This will cause a runtime error when using the default config, as iterating over a string yields individual characters.

Additionally:

  1. Line 373: Accessing ['model'] without .get() will raise KeyError if not specified.
  2. Line 378: Accessing ['model_args_extra'] without .get() will raise KeyError if not specified.
  3. Line 378: end_point_map[model] will raise KeyError if model is not one of the two supported types.
Proposed fix

First, update the default config structure to match the new expected format (lines 152-163):

     if 'accuracy' not in config:
         config['accuracy'] = {
-            'enable_accuracy_test':
-            False,
-            'model':
-            'local-completions',
-            'tasks':
-            'gsm8k',
-            'model_args_extra':
-            'num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096'
+            'enable_accuracy_test': False,
+            'tasks': {
+                'gsm8k': {
+                    'model': 'local-completions',
+                    'model_args_extra': 'num_concurrent=512,max_retries=3,tokenized_requests=false,timeout=1200,max_gen_toks=256,max_length=4096'
+                }
+            }
         }

Then, add validation and safe access for the task loop:

     if config['accuracy']['enable_accuracy_test']:
         install_dep_cmd = "pip3 install lm_eval[api]==0.4.9.2"
         client_cmds.append(" ".join(client_slurm_prefix) + " " + install_dep_cmd)
+        supported_models = {'local-completions', 'local-chat-completions'}
         for task in config['accuracy']['tasks']:
             extra_kwargs = config['accuracy']['tasks'][task].get('extra_kwargs', {})
             extra_kwargs_str = ""
             for key, value in extra_kwargs.items():
                 if isinstance(value, bool):
                     if value:
                         extra_kwargs_str += f" --{key}"
                 else:
                     extra_kwargs_str += f" --{key}='{value}'"
             end_point_map = {
                 'local-completions': 'v1/completions',
                 'local-chat-completions': 'v1/chat/completions',
             }
-            model = config['accuracy']['tasks'][task]['model']
+            model = config['accuracy']['tasks'][task].get('model', 'local-completions')
+            if model not in supported_models:
+                raise ValueError(f"Unsupported model type '{model}' for task '{task}'. Supported: {supported_models}")
+            model_args_extra = config['accuracy']['tasks'][task].get('model_args_extra', '')
             accuracy_cmd = [
                 'lm_eval',
                 '--model', model,
                 '--tasks', task,
-                '--model_args', f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{config['accuracy']['tasks'][task]['model_args_extra']}",
+                '--model_args', f"model={env_config['model_path']},base_url=http://{disagg_server_hostname}:{disagg_server_port}/{end_point_map[model]},{model_args_extra}",
                 '--log_samples',
                 '--output_path', f'{log_dir}/accuracy_eval_{task}',
                 extra_kwargs_str,
                 f"&> {log_dir}/7_accuracy_eval_{task}.log"
             ]
             client_cmds.append(" ".join(client_slurm_prefix + accuracy_cmd))

Committable suggestion skipped: line range outside the PR's diff.
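
The normalization and safe access described in this comment could look like the following minimal sketch (the helper names, the comma-split of the legacy string form, and the `local-completions` default are assumptions, not from the PR):

```python
def normalize_tasks(tasks):
    """Accept the legacy string form ('gsm8k') or the new dict form,
    returning {task_name: per_task_settings}."""
    if isinstance(tasks, str):
        return {name: {} for name in tasks.split(",")}
    return dict(tasks)

def task_settings(tasks, task):
    """Look up per-task settings with .get() so missing keys never raise KeyError."""
    cfg = tasks.get(task, {})
    model = cfg.get('model', 'local-completions')  # assumed default
    end_point_map = {
        'local-completions': 'v1/completions',
        'local-chat-completions': 'v1/chat/completions',
    }
    if model not in end_point_map:  # explicit validation instead of a bare KeyError
        raise ValueError(f"Unsupported model type {model!r} for task {task!r}")
    return model, end_point_map[model], cfg.get('model_args_extra', '')

tasks = normalize_tasks('gsm8k')  # legacy default config still works
print(task_settings(tasks, 'gsm8k'))
```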


     with open(os.path.join(log_dir, "client_cmds.sh"), "w") as f:
         f.write("\n".join(client_cmds) + "\n")

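The `extra_kwargs`-to-CLI-flags conversion added in this hunk can be exercised on its own; this sketch (the helper name `kwargs_to_flags` is illustrative) mirrors the loop's boolean handling: `True` becomes a bare flag, `False` is dropped, and anything else becomes `--key='value'`.

```python
def kwargs_to_flags(extra_kwargs):
    flags = ""
    for key, value in extra_kwargs.items():
        if isinstance(value, bool):
            if value:  # True booleans become bare flags; False is omitted
                flags += f" --{key}"
        else:          # everything else becomes --key='value'
            flags += f" --{key}='{value}'"
    return flags

print(kwargs_to_flags({'apply_chat_template': True,
                       'fewshot_as_multiturn': False,
                       'num_fewshot': 5}))
# → " --apply_chat_template --num_fewshot='5'"
```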