Skip to content

Commit 110820b

Browse files
authored
[TRTLLM-9792] [feat] Support multiple instances on single node for slurm scripts (#9900)
Signed-off-by: Kaiyu Xie <[email protected]>
1 parent bd441e9 commit 110820b

File tree

4 files changed

+186
-146
lines changed

4 files changed

+186
-146
lines changed

examples/disaggregated/slurm/benchmark/disaggr_torch.slurm

Lines changed: 44 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,19 @@ set -euo pipefail
44
# Parse named arguments
55
while [[ $# -gt 0 ]]; do
66
case $1 in
7-
# Hardware configuration
8-
--gpus-per-node) gpus_per_node="$2"; shift 2 ;;
9-
--numa-bind) numa_bind="$2"; shift 2 ;;
10-
--ctx-nodes) ctx_nodes="$2"; shift 2 ;;
11-
--gen-nodes) gen_nodes="$2"; shift 2 ;;
12-
--ctx-world-size) ctx_world_size="$2"; shift 2 ;;
13-
--gen-world-size) gen_world_size="$2"; shift 2 ;;
147
# Worker configuration
158
--num-ctx-servers) num_ctx_servers="$2"; shift 2 ;;
16-
--ctx-config-path) ctx_config_path="$2"; shift 2 ;;
179
--num-gen-servers) num_gen_servers="$2"; shift 2 ;;
18-
--gen-config-path) gen_config_path="$2"; shift 2 ;;
1910
--concurrency-list) concurrency_list="$2"; shift 2 ;;
11+
2012
# Sequence and benchmark parameters
2113
--isl) isl="$2"; shift 2 ;;
2214
--osl) osl="$2"; shift 2 ;;
2315
--multi-round) multi_round="$2"; shift 2 ;;
2416
--benchmark-ratio) benchmark_ratio="$2"; shift 2 ;;
2517
--streaming) streaming="$2"; shift 2 ;;
2618
--use-nv-sa-benchmark) use_nv_sa_benchmark="$2"; shift 2 ;;
27-
--benchmark-mode) benchmark_mode="$2"; shift 2 ;;
28-
--cache-max-tokens) cache_max_tokens="$2"; shift 2 ;;
19+
2920
# Environment and paths
3021
--dataset-file) dataset_file="$2"; shift 2 ;;
3122
--model-path) model_path="$2"; shift 2 ;;
@@ -36,17 +27,13 @@ while [[ $# -gt 0 ]]; do
3627
--container-image) container_image="$2"; shift 2 ;;
3728
--build-wheel) build_wheel="$2"; shift 2 ;;
3829
--trtllm-wheel-path) trtllm_wheel_path="$2"; shift 2 ;;
39-
# Profiling
40-
--nsys-on) nsys_on="$2"; shift 2 ;;
41-
--ctx-profile-range) ctx_profile_range="$2"; shift 2 ;;
42-
--gen-profile-range) gen_profile_range="$2"; shift 2 ;;
30+
4331
# Accuracy evaluation
4432
--enable-accuracy-test) enable_accuracy_test="$2"; shift 2 ;;
4533
--accuracy-model) accuracy_model="$2"; shift 2 ;;
4634
--accuracy-tasks) accuracy_tasks="$2"; shift 2 ;;
4735
--model-args-extra) model_args_extra="$2"; shift 2 ;;
48-
# Worker environment variables
49-
--worker-env-var) worker_env_var="$2"; shift 2 ;;
36+
5037
# Server environment variables
5138
--server-env-var) server_env_var="$2"; shift 2 ;;
5239
*)
@@ -58,60 +45,42 @@ done
5845

5946
# Print all parsed arguments
6047
echo "Parsed arguments:"
61-
echo "Hardware Configuration:"
62-
echo " gpus_per_node: ${gpus_per_node}"
63-
echo " numa_bind: ${numa_bind}"
64-
echo " ctx_nodes: ${ctx_nodes}"
65-
echo " gen_nodes: ${gen_nodes}"
66-
echo " ctx_world_size: ${ctx_world_size}"
67-
echo " gen_world_size: ${gen_world_size}"
6848
echo
6949
echo "Worker Configuration:"
7050
echo " num_ctx_servers: ${num_ctx_servers}"
71-
echo " ctx_config_path: ${ctx_config_path}"
7251
echo " num_gen_servers: ${num_gen_servers}"
73-
echo " gen_config_path: ${gen_config_path}"
7452
echo " concurrency_list: ${concurrency_list}"
7553
echo
7654
echo "Benchmark Configuration:"
77-
echo " use_nv_sa_benchmark: ${use_nv_sa_benchmark}"
7855
echo " isl: ${isl}"
7956
echo " osl: ${osl}"
8057
echo " multi_round: ${multi_round}"
8158
echo " benchmark_ratio: ${benchmark_ratio}"
8259
echo " streaming: ${streaming}"
83-
echo " cache_max_tokens: ${cache_max_tokens}"
84-
echo " benchmark_mode: ${benchmark_mode}"
60+
echo " use_nv_sa_benchmark: ${use_nv_sa_benchmark}"
8561
echo
8662
echo "Environment Configuration:"
8763
echo " dataset_file: ${dataset_file}"
88-
echo " container_mount: ${container_mount}"
89-
echo " container_image: ${container_image}"
9064
echo " model_path: ${model_path}"
9165
echo " trtllm_repo: ${trtllm_repo}"
66+
echo " work_dir: ${work_dir}"
67+
echo " full_logdir: ${full_logdir}"
68+
echo " container_mount: ${container_mount}"
69+
echo " container_image: ${container_image}"
9270
echo " build_wheel: ${build_wheel}"
9371
echo " trtllm_wheel_path: ${trtllm_wheel_path}"
94-
echo " work_dir: ${work_dir}"
95-
echo " nsys_on: ${nsys_on}"
96-
echo " ctx_profile_range: ${ctx_profile_range}"
97-
echo " gen_profile_range: ${gen_profile_range}"
9872
echo
9973
echo "Accuracy Configuration:"
10074
echo " enable_accuracy_test: ${enable_accuracy_test}"
10175
echo " accuracy_model: ${accuracy_model}"
10276
echo " accuracy_tasks: ${accuracy_tasks}"
10377
echo " model_args_extra: ${model_args_extra}"
10478
echo
105-
echo "Worker Environment Variables:"
106-
echo " worker_env_var: ${worker_env_var}"
107-
echo
10879
echo "Server Environment Variables:"
10980
echo " server_env_var: ${server_env_var}"
11081

11182
container_name="disaggr-test"
11283

113-
echo "Log directory: ${full_logdir}"
114-
11584
# Function to cleanup on failure
11685
cleanup_on_failure() {
11786
echo "Error: $1"
@@ -128,8 +97,8 @@ if ! srun -l --container-image=${container_image} \
12897
--container-name=${container_name} \
12998
--container-mounts=${container_mount} \
13099
--mpi=pmix \
131-
echo "Container up." &> ${full_logdir}/container_launch.log; then
132-
cleanup_on_failure "Failed to start container. Check ${full_logdir}/container_launch.log"
100+
echo "Container up." &> ${full_logdir}/1_container_launch.log; then
101+
cleanup_on_failure "Failed to start container. Check ${full_logdir}/1_container_launch.log"
133102
fi
134103

135104
# Install TensorRT-LLM
@@ -140,8 +109,8 @@ if [ -n "${trtllm_wheel_path}" ]; then
140109
--container-mounts=${container_mount} --no-container-mount-home \
141110
--mpi=pmix --overlap -N $SLURM_NNODES --ntasks-per-node=1 \
142111
bash -c "pip install ${trtllm_wheel_path}" \
143-
&> ${full_logdir}/install.log; then
144-
cleanup_on_failure "TensorRT-LLM wheel installation failed. Check ${full_logdir}/install.log for details"
112+
&> ${full_logdir}/2_install.log; then
113+
cleanup_on_failure "TensorRT-LLM wheel installation failed. Check ${full_logdir}/2_install.log for details"
145114
fi
146115
echo "TensorRT-LLM wheel installation completed successfully"
147116
elif [ -d "${trtllm_repo}" ]; then
@@ -157,8 +126,8 @@ elif [ -d "${trtllm_repo}" ]; then
157126
--container-mounts=${container_mount} \
158127
--mpi=pmix --overlap -N 1 --ntasks-per-node=1 \
159128
bash -c "cd ${trtllm_repo} && ${build_command}" \
160-
&> ${full_logdir}/build.log; then
161-
cleanup_on_failure "TensorRT-LLM build failed. Check ${full_logdir}/build.log for details"
129+
&> ${full_logdir}/2_build.log; then
130+
cleanup_on_failure "TensorRT-LLM build failed. Check ${full_logdir}/2_build.log for details"
162131
fi
163132
echo "TensorRT-LLM build completed successfully"
164133
fi
@@ -168,60 +137,33 @@ elif [ -d "${trtllm_repo}" ]; then
168137
--container-mounts=${container_mount} --no-container-mount-home \
169138
--mpi=pmix --overlap -N $SLURM_NNODES --ntasks-per-node=1 \
170139
bash -c "cd ${trtllm_repo} && pip install -e ." \
171-
&> ${full_logdir}/install.log; then
172-
cleanup_on_failure "TensorRT-LLM installation failed. Check ${full_logdir}/install.log for details"
140+
&> ${full_logdir}/2_install.log; then
141+
cleanup_on_failure "TensorRT-LLM installation failed. Check ${full_logdir}/2_install.log for details"
173142
fi
174143
echo "TensorRT-LLM installation completed successfully"
175144
fi
176145

177-
# Get node lists
146+
# Get node lists and replace the placeholder with the actual node names
147+
echo "SLURM_NODELIST: ${SLURM_NODELIST}"
178148
all_nodes=($(scontrol show hostname $SLURM_NODELIST | sort))
179-
total_nodes_num=${#all_nodes[@]}
180-
echo "all_nodes: ${all_nodes[@]}, total_nodes_num: ${total_nodes_num}"
181-
182-
# Split nodes between gen and ctx workers
183-
gen_node_list=(${all_nodes[@]:0:${gen_nodes}})
184-
ctx_node_list=(${all_nodes[@]:${gen_nodes}:${total_nodes_num}})
185-
186-
echo "gen_nodes: ${gen_node_list[@]}, num_nodes: ${gen_nodes}"
187-
echo "ctx_nodes: ${ctx_node_list[@]}, num_nodes: ${ctx_nodes}"
188-
189-
rm -rf ${full_logdir}/hostnames
190-
rm -rf ${full_logdir}/server_config.yaml
191-
192-
gen_nodes_num_in_single_server=$((${gen_nodes} / ${num_gen_servers}))
193-
ctx_nodes_num_in_single_server=$((${ctx_nodes} / ${num_ctx_servers}))
194-
echo "gen_nodes_num_in_single_server: ${gen_nodes_num_in_single_server}"
195-
echo "ctx_nodes_num_in_single_server: ${ctx_nodes_num_in_single_server}"
196-
197-
# start the gen workers
198-
echo "Starting gen workers..."
199-
for i in $(seq 0 $((num_gen_servers - 1))); do
200-
srun -l -N ${gen_nodes_num_in_single_server} \
201-
--ntasks=$((gen_world_size)) \
202-
--ntasks-per-node=${gpus_per_node} \
203-
--container-image=${container_image} \
204-
--container-name=${container_name} \
205-
--container-mounts=${container_mount} \
206-
--mpi=pmix \
207-
bash ${work_dir}/start_worker.sh \
208-
"GEN" ${i} ${model_path} "8336" "${benchmark_mode}" "${concurrency_list}" "${numa_bind}" "${full_logdir}" "${nsys_on}" "${gen_profile_range}" "${gen_config_path}" "${worker_env_var}" \
209-
&> ${full_logdir}/output_gen_${i}.log &
149+
all_nodes_str=$(IFS=','; echo "${all_nodes[*]}")
150+
echo "all_nodes_str: ${all_nodes_str}"
151+
152+
start_worker_cmds_file=${full_logdir}/start_worker_cmds.txt
153+
IFS=',' read -r -a node_array <<< "$all_nodes_str"
154+
for i in "${!node_array[@]}"; do
155+
current_val="${node_array[$i]}"
156+
placeholder="<node${i}_placeholder>"
157+
158+
# Use sed to replace the placeholder with the value in-place
159+
sed -i "s|$placeholder|$current_val|g" "${start_worker_cmds_file}"
160+
echo "Replaced $placeholder with $current_val"
210161
done
211162

212-
# start the ctx workers
213-
echo "Starting ctx workers..."
214-
for i in $(seq 0 $((num_ctx_servers - 1))); do
215-
srun -l -N ${ctx_nodes_num_in_single_server} \
216-
--ntasks=$((ctx_world_size)) \
217-
--ntasks-per-node=${gpus_per_node} \
218-
--container-image=${container_image} \
219-
--container-name=${container_name} \
220-
--container-mounts=${container_mount} \
221-
--mpi=pmix \
222-
bash ${work_dir}/start_worker.sh \
223-
"CTX" ${i} ${model_path} "8336" "${benchmark_mode}" "${concurrency_list}" "${numa_bind}" "${full_logdir}" "${nsys_on}" "${ctx_profile_range}" "${ctx_config_path}" "${worker_env_var}" \
224-
&> ${full_logdir}/output_ctx_${i}.log &
163+
echo "Starting worker commands from ${start_worker_cmds_file}..."
164+
cat ${start_worker_cmds_file} | while read cmd; do
165+
echo "Starting worker command: ${cmd}"
166+
eval "${cmd}"
225167
done
226168

227169
# start the server (in background)
@@ -231,16 +173,16 @@ srun -l --container-name=${container_name} \
231173
--container-mounts=${container_mount} \
232174
--mpi=pmix --overlap -N 1 -n 1 \
233175
bash ${work_dir}/start_server.sh ${num_ctx_servers} ${num_gen_servers} ${full_logdir} ${work_dir} "${server_env_var}" \
234-
&> ${full_logdir}/output_server.log &
176+
&> ${full_logdir}/4_output_server.log &
235177

236178
# Wait for server to be ready (runs synchronously)
237179
echo "Waiting for server to be ready..."
238180
if ! srun -l --container-name=${container_name} \
239181
--container-mounts=${container_mount} \
240182
--mpi=pmix --overlap -N 1 -n 1 \
241183
bash ${work_dir}/wait_server.sh ${full_logdir} \
242-
&> ${full_logdir}/wait_server.log; then
243-
cleanup_on_failure "Server failed to become ready. Check ${full_logdir}/wait_server.log for details"
184+
&> ${full_logdir}/5_wait_server.log; then
185+
cleanup_on_failure "Server failed to become ready. Check ${full_logdir}/5_wait_server.log for details"
244186
fi
245187
echo "Server is ready!"
246188

@@ -253,8 +195,8 @@ if [ "${use_nv_sa_benchmark}" = "true" ]; then
253195
--mpi=pmix --overlap -N 1 -n 1 \
254196
bash ${work_dir}/run_benchmark_nv_sa.sh \
255197
"${model_path}" "${isl}" "${osl}" "${benchmark_ratio}" "${multi_round}" "${num_gen_servers}" "${concurrency_list}" "${streaming}" "${full_logdir}/" \
256-
&> ${full_logdir}/bench.log; then
257-
cleanup_on_failure "NVIDIA SA benchmark failed. Check ${full_logdir}/bench.log for details"
198+
&> ${full_logdir}/6_bench.log; then
199+
cleanup_on_failure "NVIDIA SA benchmark failed. Check ${full_logdir}/6_bench.log for details"
258200
fi
259201
else
260202
echo "Using default benchmark script..."
@@ -263,8 +205,8 @@ else
263205
--mpi=pmix --overlap -N 1 -n 1 \
264206
bash ${work_dir}/run_benchmark.sh \
265207
"${model_path}" "${dataset_file}" "${multi_round}" "${num_gen_servers}" "${concurrency_list}" "${streaming}" "${full_logdir}/" \
266-
&> ${full_logdir}/bench.log; then
267-
cleanup_on_failure "Benchmark failed. Check ${full_logdir}/bench.log for details"
208+
&> ${full_logdir}/6_bench.log; then
209+
cleanup_on_failure "Benchmark failed. Check ${full_logdir}/6_bench.log for details"
268210
fi
269211
fi
270212
echo "Benchmark completed successfully"
@@ -278,8 +220,8 @@ if [ "${enable_accuracy_test}" = "true" ]; then
278220
bash ${work_dir}/accuracy_eval.sh \
279221
"${full_logdir}" "${accuracy_model}" "${accuracy_tasks}" "${model_path}" \
280222
"${model_args_extra}" "${full_logdir}/accuracy_eval" \
281-
&> ${full_logdir}/accuracy_eval.log; then
282-
cleanup_on_failure "Accuracy evaluation failed. Check ${full_logdir}/accuracy_eval.log for details"
223+
&> ${full_logdir}/7_accuracy_eval.log; then
224+
cleanup_on_failure "Accuracy evaluation failed. Check ${full_logdir}/7_accuracy_eval.log for details"
283225
fi
284226
echo "Accuracy evaluation completed successfully"
285227
fi

examples/disaggregated/slurm/benchmark/gen_server_config.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@
1919
type=str,
2020
default="logs",
2121
help="Work directory")
22-
parser.add_argument("--worker_port",
23-
type=int,
24-
default=8336,
25-
help="Worker port")
2622
parser.add_argument("--server_port",
2723
type=int,
2824
default=8333,
@@ -49,21 +45,21 @@
4945
print(f"All hostnames found in {hostnames_folder}")
5046

5147
# get the ctx and gen URLs (hostname:port) from the files in the hostnames folder
52-
ctx_hostnames = []
53-
gen_hostnames = []
48+
ctx_urls = []
49+
gen_urls = []
5450
for hostname_file in hostnames:
5551
hostname_file_path = os.path.join(hostnames_folder, hostname_file)
5652
with open(hostname_file_path, 'r') as f:
57-
actual_hostname = f.read().strip()
58-
print(f"Hostname: {actual_hostname} in {hostname_file}")
53+
url = f.read().strip()
54+
print(f"url: {url} in {hostname_file}")
5955

60-
if hostname_file.startswith("CTX"):
61-
ctx_hostnames.append(actual_hostname)
62-
elif hostname_file.startswith("GEN"):
63-
gen_hostnames.append(actual_hostname)
56+
if hostname_file.startswith("CTX"):
57+
ctx_urls.append(url)
58+
elif hostname_file.startswith("GEN"):
59+
gen_urls.append(url)
6460

65-
print(f"ctx_hostnames: {ctx_hostnames}")
66-
print(f"gen_hostnames: {gen_hostnames}")
61+
print(f"ctx_urls: {ctx_urls}")
62+
print(f"gen_urls: {gen_urls}")
6763

6864
# get current hostname from env
6965
hostname = socket.gethostname()
@@ -75,11 +71,11 @@
7571
'backend': 'pytorch',
7672
'context_servers': {
7773
'num_instances': args.num_ctx_servers,
78-
'urls': [f'{host}:{args.worker_port}' for host in ctx_hostnames]
74+
'urls': ctx_urls
7975
},
8076
'generation_servers': {
8177
'num_instances': args.num_gen_servers,
82-
'urls': [f'{host}:{args.worker_port}' for host in gen_hostnames]
78+
'urls': gen_urls
8379
}
8480
}
8581

examples/disaggregated/slurm/benchmark/start_worker.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ echo "config_file: ${config_file}"
4343
# if SLURM_NODEID is 0, save the hostname:port to a file
4444
if [ "${SLURM_NODEID}" = "0" ]; then
4545
mkdir -p ${log_dir}/hostnames/
46-
echo $(hostname) > ${log_dir}/hostnames/${role}_${instance_id}.txt
47-
echo "hostname saved to ${log_dir}/hostnames/${role}_${instance_id}.txt"
46+
echo $(hostname):${port} > ${log_dir}/hostnames/${role}_${instance_id}.txt
47+
echo "hostname:port saved to ${log_dir}/hostnames/${role}_${instance_id}.txt"
4848
fi
4949

5050
nsys_prefix=""

0 commit comments

Comments
 (0)