Skip to content

Commit ef4ea95

Browse files
authored
[None] [fix] Fix slurm scripts (#10007)
Signed-off-by: Kaiyu Xie <[email protected]>
1 parent ad12b79 commit ef4ea95

File tree

3 files changed

+20
-15
lines changed

3 files changed

+20
-15
lines changed

examples/disaggregated/slurm/benchmark/disaggr_torch.slurm

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ while [[ $# -gt 0 ]]; do
1616
--benchmark-ratio) benchmark_ratio="$2"; shift 2 ;;
1717
--streaming) streaming="$2"; shift 2 ;;
1818
--use-nv-sa-benchmark) use_nv_sa_benchmark="$2"; shift 2 ;;
19+
--benchmark-mode) benchmark_mode="$2"; shift 2 ;;
1920

2021
# Environment and paths
2122
--dataset-file) dataset_file="$2"; shift 2 ;;
@@ -59,6 +60,7 @@ echo " multi_round: ${multi_round}"
5960
echo " benchmark_ratio: ${benchmark_ratio}"
6061
echo " streaming: ${streaming}"
6162
echo " use_nv_sa_benchmark: ${use_nv_sa_benchmark}"
63+
echo " benchmark_mode: ${benchmark_mode}"
6264
echo
6365
echo "Environment Configuration:"
6466
echo " dataset_file: ${dataset_file}"

examples/disaggregated/slurm/benchmark/gen_server_config.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,8 @@
7878
'port': args.server_port,
7979
'backend': 'pytorch',
8080
'context_servers': {
81-
'num_instances':
82-
0 if gen_only else args.num_ctx_servers,
83-
'urls': [] if gen_only else
84-
[f'{host}:{args.worker_port}' for host in ctx_hostnames]
81+
'num_instances': 0 if gen_only else args.num_ctx_servers,
82+
'urls': [] if gen_only else ctx_urls
8583
},
8684
'generation_servers': {
8785
'num_instances': args.num_gen_servers,

examples/disaggregated/slurm/benchmark/submit.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ def parse_args():
2626
'--dir',
2727
type=str,
2828
help='Directory containing YAML configuration files')
29-
group.add_argument('--log-dir',
30-
type=str,
31-
default=None,
32-
help='Log directory')
29+
parser.add_argument('--log-dir',
30+
type=str,
31+
default=None,
32+
help='Log directory')
3333
return parser.parse_args()
3434

3535

@@ -154,16 +154,20 @@ def submit_job(config, log_dir):
154154
{}).get('num_nextn_predict_layers', 0)
155155

156156
# Calculate nodes based on world sizes
157-
ctx_tp_size = config['worker_config']['ctx']['tensor_parallel_size']
158-
ctx_cp_size = config['worker_config']['ctx']['context_parallel_size']
159-
ctx_pp_size = config['worker_config']['ctx']['pipeline_parallel_size']
157+
ctx_tp_size = config['worker_config']['ctx'].get('tensor_parallel_size', 1)
158+
ctx_cp_size = config['worker_config']['ctx'].get('context_parallel_size', 1)
159+
ctx_pp_size = config['worker_config']['ctx'].get('pipeline_parallel_size',
160+
1)
160161
ctx_world_size = ctx_tp_size * ctx_cp_size * ctx_pp_size
161162
ctx_nodes = calculate_nodes(ctx_world_size, ctx_num, gpus_per_node)
162-
gen_tp_size = config['worker_config']['gen']['tensor_parallel_size']
163-
gen_cp_size = config['worker_config']['gen']['context_parallel_size']
164-
gen_pp_size = config['worker_config']['gen']['pipeline_parallel_size']
163+
164+
gen_tp_size = config['worker_config']['gen'].get('tensor_parallel_size', 1)
165+
gen_cp_size = config['worker_config']['gen'].get('context_parallel_size', 1)
166+
gen_pp_size = config['worker_config']['gen'].get('pipeline_parallel_size',
167+
1)
165168
gen_world_size = gen_tp_size * gen_cp_size * gen_pp_size
166169
gen_nodes = calculate_nodes(gen_world_size, gen_num, gpus_per_node)
170+
167171
total_nodes = ctx_nodes + gen_nodes
168172
total_tasks = total_nodes * gpus_per_node
169173

@@ -259,7 +263,7 @@ def submit_job(config, log_dir):
259263
str(allocation["port"]),
260264
config['benchmark']['mode'],
261265
config['benchmark']['concurrency_list'],
262-
str(slurm_config['numa_bind']),
266+
str(slurm_config['numa_bind']).lower(),
263267
log_dir,
264268
str(profiling_config['nsys_on']).lower(),
265269
profiling_config['gen_profile_range']
@@ -303,6 +307,7 @@ def submit_job(config, log_dir):
303307
'--benchmark-ratio', str(config['benchmark']['benchmark_ratio']),
304308
'--streaming', str(config['benchmark']['streaming']).lower(),
305309
'--use-nv-sa-benchmark', str(config['benchmark']['use_nv_sa_benchmark']).lower(),
310+
'--benchmark-mode', config['benchmark']['mode'],
306311

307312
# Environment and paths
308313
'--dataset-file', config['benchmark']['dataset_file'],

0 commit comments

Comments
 (0)