Skip to content

Commit 627601f

Browse files
dc3671 authored and codego7250 committed
[None][chore] refactor disaggregated scripts to use named arguments (NVIDIA#9581)
Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
1 parent cda0749 commit 627601f

File tree

2 files changed

+93
-90
lines changed

2 files changed

+93
-90
lines changed

examples/disaggregated/slurm/benchmark/disaggr_torch.slurm

Lines changed: 54 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,60 @@
11
#!/bin/bash
22
set -euo pipefail
33

4-
# Parse arguments
5-
# Hardware configuration
6-
gpus_per_node=${1}
7-
numa_bind=${2}
8-
ctx_nodes=${3} # Number of nodes needed for ctx workers
9-
gen_nodes=${4} # Number of nodes needed for gen workers
10-
ctx_world_size=${5} # World size for ctx workers
11-
gen_world_size=${6} # World size for gen workers
12-
13-
# Worker configuration
14-
num_ctx_servers=${7}
15-
ctx_config_path=${8}
16-
num_gen_servers=${9}
17-
gen_config_path=${10}
18-
concurrency_list=${11}
19-
20-
# Sequence and benchmark parameters
21-
isl=${12}
22-
osl=${13}
23-
multi_round=${14}
24-
benchmark_ratio=${15}
25-
streaming=${16}
26-
use_nv_sa_benchmark=${17}
27-
benchmark_mode=${18}
28-
cache_max_tokens=${19}
29-
30-
# Environment and paths
31-
dataset_file=${20}
32-
model_path=${21}
33-
trtllm_repo=${22}
34-
work_dir=${23}
35-
full_logdir=${24}
36-
container_mount=${25}
37-
container_image=${26}
38-
build_wheel=${27}
39-
trtllm_wheel_path=${28}
40-
41-
# Profiling
42-
nsys_on=${29}
43-
ctx_profile_range=${30}
44-
gen_profile_range=${31}
45-
46-
# Accuracy evaluation
47-
enable_accuracy_test=${32}
48-
accuracy_model=${33}
49-
accuracy_tasks=${34}
50-
model_args_extra=${35}
51-
52-
# Worker environment variables
53-
worker_env_var=${36}
54-
55-
# Server environment variables
56-
server_env_var=${37}
4+
# Parse named arguments
5+
#######################################
# Parse the named command-line arguments into global variables.
#
# Every recognised option takes exactly one value, which is stored in the
# variable named after the option (e.g. --gpus-per-node -> gpus_per_node).
# An empty string is accepted as a value; only a value that is absent
# altogether (option given as the very last word) is rejected.
#
# Arguments: the script's argument list ("$@")
# Outputs:   diagnostics on stderr
# Returns:   exits the script with status 1 on an unknown option or on an
#            option that has no value following it
#######################################
parse_args() {
    local opt val
    while [[ $# -gt 0 ]]; do
        opt="$1"
        # Default to empty so the expansion is safe under `set -u`; whether a
        # value was really supplied is verified once, after the case statement.
        val="${2-}"
        case "$opt" in
            # Hardware configuration
            --gpus-per-node) gpus_per_node="$val" ;;
            --numa-bind) numa_bind="$val" ;;
            --ctx-nodes) ctx_nodes="$val" ;;
            --gen-nodes) gen_nodes="$val" ;;
            --ctx-world-size) ctx_world_size="$val" ;;
            --gen-world-size) gen_world_size="$val" ;;
            # Worker configuration
            --num-ctx-servers) num_ctx_servers="$val" ;;
            --ctx-config-path) ctx_config_path="$val" ;;
            --num-gen-servers) num_gen_servers="$val" ;;
            --gen-config-path) gen_config_path="$val" ;;
            --concurrency-list) concurrency_list="$val" ;;
            # Sequence and benchmark parameters
            --isl) isl="$val" ;;
            --osl) osl="$val" ;;
            --multi-round) multi_round="$val" ;;
            --benchmark-ratio) benchmark_ratio="$val" ;;
            --streaming) streaming="$val" ;;
            --use-nv-sa-benchmark) use_nv_sa_benchmark="$val" ;;
            --benchmark-mode) benchmark_mode="$val" ;;
            --cache-max-tokens) cache_max_tokens="$val" ;;
            # Environment and paths
            --dataset-file) dataset_file="$val" ;;
            --model-path) model_path="$val" ;;
            --trtllm-repo) trtllm_repo="$val" ;;
            --work-dir) work_dir="$val" ;;
            --full-logdir) full_logdir="$val" ;;
            --container-mount) container_mount="$val" ;;
            --container-image) container_image="$val" ;;
            --build-wheel) build_wheel="$val" ;;
            --trtllm-wheel-path) trtllm_wheel_path="$val" ;;
            # Profiling
            --nsys-on) nsys_on="$val" ;;
            --ctx-profile-range) ctx_profile_range="$val" ;;
            --gen-profile-range) gen_profile_range="$val" ;;
            # Accuracy evaluation
            --enable-accuracy-test) enable_accuracy_test="$val" ;;
            --accuracy-model) accuracy_model="$val" ;;
            --accuracy-tasks) accuracy_tasks="$val" ;;
            --model-args-extra) model_args_extra="$val" ;;
            # Worker environment variables
            --worker-env-var) worker_env_var="$val" ;;
            # Server environment variables
            --server-env-var) server_env_var="$val" ;;
            *)
                echo "Unknown argument: $opt" >&2
                exit 1
                ;;
        esac
        # Reaching this point means the option is known. Fail with a message
        # that names it rather than letting `shift 2` (or `set -u`) abort with
        # an error that does not say which option was incomplete.
        if [[ $# -lt 2 ]]; then
            echo "Missing value for argument: $opt" >&2
            exit 1
        fi
        shift 2
    done
}

parse_args "$@"
5758

5859
# Print all parsed arguments
5960
echo "Parsed arguments:"

examples/disaggregated/slurm/benchmark/submit.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def submit_job(config, log_dir):
150150
save_worker_config(config, gen_config_path, 'gen')
151151

152152
# Prepare sbatch command
153+
# yapf: disable
153154
cmd = [
154155
'sbatch',
155156
f'--partition={slurm_config["partition"]}',
@@ -163,59 +164,60 @@ def submit_job(config, log_dir):
163164
*([arg for arg in slurm_config['extra_args'].split() if arg]),
164165
slurm_config['script_file'],
165166
# Hardware configuration
166-
str(hw_config['gpus_per_node']),
167-
str(slurm_config['numa_bind']).lower(),
168-
str(ctx_nodes), # Number of nodes needed for ctx workers
169-
str(gen_nodes), # Number of nodes needed for gen workers
170-
str(ctx_world_size), # World size for ctx workers
171-
str(gen_world_size), # World size for gen workers
167+
'--gpus-per-node', str(hw_config['gpus_per_node']),
168+
'--numa-bind', str(slurm_config['numa_bind']).lower(),
169+
'--ctx-nodes', str(ctx_nodes), # Number of nodes needed for ctx workers
170+
'--gen-nodes', str(gen_nodes), # Number of nodes needed for gen workers
171+
'--ctx-world-size', str(ctx_world_size), # World size for ctx workers
172+
'--gen-world-size', str(gen_world_size), # World size for gen workers
172173

173174
# Worker configuration
174-
str(ctx_num),
175-
ctx_config_path,
176-
str(gen_num),
177-
gen_config_path,
178-
config['benchmark']['concurrency_list'],
175+
'--num-ctx-servers', str(ctx_num),
176+
'--ctx-config-path', ctx_config_path,
177+
'--num-gen-servers', str(gen_num),
178+
'--gen-config-path', gen_config_path,
179+
'--concurrency-list', config['benchmark']['concurrency_list'],
179180

180181
# Sequence and benchmark parameters
181-
str(config['benchmark']['input_length']),
182-
str(config['benchmark']['output_length']),
183-
str(config['benchmark']['multi_round']),
184-
str(config['benchmark']['benchmark_ratio']),
185-
str(config['benchmark']['streaming']).lower(),
186-
str(config['benchmark']['use_nv_sa_benchmark']).lower(),
187-
config['benchmark']['mode'],
188-
str(config['worker_config']['gen']['cache_transceiver_config']
182+
'--isl', str(config['benchmark']['input_length']),
183+
'--osl', str(config['benchmark']['output_length']),
184+
'--multi-round', str(config['benchmark']['multi_round']),
185+
'--benchmark-ratio', str(config['benchmark']['benchmark_ratio']),
186+
'--streaming', str(config['benchmark']['streaming']).lower(),
187+
'--use-nv-sa-benchmark', str(config['benchmark']['use_nv_sa_benchmark']).lower(),
188+
'--benchmark-mode', config['benchmark']['mode'],
189+
'--cache-max-tokens', str(config['worker_config']['gen']['cache_transceiver_config']
189190
['max_tokens_in_buffer']),
190191

191192
# Environment and paths
192-
config['benchmark']['dataset_file'],
193-
env_config['model_path'],
194-
env_config['trtllm_repo'],
195-
env_config['work_dir'],
196-
log_dir, # Pass the generated log directory
197-
env_config['container_mount'],
198-
env_config['container_image'],
199-
str(env_config['build_wheel']).lower(),
200-
env_config['trtllm_wheel_path'],
193+
'--dataset-file', config['benchmark']['dataset_file'],
194+
'--model-path', env_config['model_path'],
195+
'--trtllm-repo', env_config['trtllm_repo'],
196+
'--work-dir', env_config['work_dir'],
197+
'--full-logdir', log_dir,
198+
'--container-mount', env_config['container_mount'],
199+
'--container-image', env_config['container_image'],
200+
'--build-wheel', str(env_config['build_wheel']).lower(),
201+
'--trtllm-wheel-path', env_config['trtllm_wheel_path'],
201202

202203
# Profiling
203-
str(profiling_config['nsys_on']).lower(),
204-
profiling_config['ctx_profile_range'],
205-
profiling_config['gen_profile_range'],
204+
'--nsys-on', str(profiling_config['nsys_on']).lower(),
205+
'--ctx-profile-range', profiling_config['ctx_profile_range'],
206+
'--gen-profile-range', profiling_config['gen_profile_range'],
206207

207208
# Accuracy evaluation
208-
str(config['accuracy']['enable_accuracy_test']).lower(),
209-
config['accuracy']['model'],
210-
config['accuracy']['tasks'],
211-
config['accuracy']['model_args_extra'],
209+
'--enable-accuracy-test', str(config['accuracy']['enable_accuracy_test']).lower(),
210+
'--accuracy-model', config['accuracy']['model'],
211+
'--accuracy-tasks', config['accuracy']['tasks'],
212+
'--model-args-extra', config['accuracy']['model_args_extra'],
212213

213214
# Worker environment variables
214-
env_config['worker_env_var'],
215+
'--worker-env-var', env_config['worker_env_var'],
215216

216217
# Server environment variables
217-
env_config['server_env_var']
218+
'--server-env-var', env_config['server_env_var']
218219
]
220+
# yapf: enable
219221

220222
# Submit the job
221223
try:

0 commit comments

Comments
 (0)