
Commit 6eef192

[None] [chore] cherry pick changes on slurm scripts from release/1.1.0rc2 (#7750)
Signed-off-by: Kaiyu Xie <[email protected]>
1 parent: b278d06

4 files changed (+24 lines, -12 lines)

examples/disaggregated/slurm/benchmark/disaggr_torch.slurm

Lines changed: 3 additions & 2 deletions

@@ -83,8 +83,6 @@ echo "ntasks_per_node: ${ntasks_per_node}"
 echo "==========================================="
 
 
-nsys_on=""
-# nsys_on=${full_logdir} # Uncomment this line to enable Nsys profiling
 numa_bind=true # Only allocate memory from nodes, this only works on GB200
 ctx_max_seq_len=$((isl + 10))
 gen_max_seq_len=$((isl + osl + 10))
@@ -96,6 +94,9 @@ logdir=${workdir}/slurm-${SLURM_JOB_ID}/benchmark-${isl}-${osl}
 mkdir -p ${logdir}
 full_logdir=${logdir}/ctx${num_ctx_servers}_gen${num_gen_servers}_dep${gen_tp_size}_batch${gen_batch_size}_eplb${eplb_num_slots}_mtp${mtp_size}
 
+nsys_on=""
+# nsys_on=${full_logdir} # Uncomment this line to enable Nsys profiling
+
 echo "concurrency: ${concurrency}"
 
 ctx_gpus=$((num_ctx_servers * ctx_tp_size * ctx_pp_size))
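
Why the nsys_on block moved: the commented-out alternative assigns nsys_on=${full_logdir}, and in the old position full_logdir was not yet defined, so uncommenting it silently produced an empty path. A minimal Python simulation of the two orderings (values are hypothetical; shell parameter expansion is modeled with dict.get):

def expand(env, name):
    return env.get(name, "")  # unset shell variables expand to "" silently

# Old ordering: nsys_on is assigned before full_logdir exists.
env = {}
env["nsys_on"] = expand(env, "full_logdir")   # -> "" (no error, just empty)
env["full_logdir"] = "/logs/ctx1_gen1"        # hypothetical value

# New ordering: full_logdir first, so nsys_on can reference it.
env2 = {"full_logdir": "/logs/ctx1_gen1"}
env2["nsys_on"] = expand(env2, "full_logdir")  # -> "/logs/ctx1_gen1"

print(repr(env["nsys_on"]), repr(env2["nsys_on"]))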

examples/disaggregated/slurm/benchmark/gen_worker_config.py

Lines changed: 10 additions & 0 deletions

@@ -48,6 +48,11 @@ def gen_config_file(work_dir: str,
         server_port: Server port
     """
     ctx_config = {
+        'build_config': {
+            'max_batch_size': ctx_batch_size,
+            'max_num_tokens': ctx_max_num_tokens,
+            'max_seq_len': ctx_max_seq_len,
+        },
         'max_batch_size': ctx_batch_size,
         'max_num_tokens': ctx_max_num_tokens,
         'max_seq_len': ctx_max_seq_len,
@@ -79,6 +84,11 @@ def gen_config_file(work_dir: str,
     gen_moe_backend = "TRTLLM"
 
     gen_config = {
+        'build_config': {
+            'max_batch_size': gen_batch_size,
+            'max_num_tokens': gen_max_num_tokens,
+            'max_seq_len': gen_max_seq_len,
+        },
         'tensor_parallel_size': gen_tp_size,
         'moe_expert_parallel_size': gen_tp_size,
         'enable_attention_dp': True if gen_enable_attention_dp else False,
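
The effect on the generated worker configs: the batch, token, and sequence limits are now mirrored into a nested build_config section alongside the existing top-level keys. A minimal sketch of the resulting context-server dict, with hypothetical example values standing in for the script's arguments:

# Hypothetical example values; the real ones come from the script's CLI args.
ctx_batch_size = 1
ctx_max_num_tokens = 4608
ctx_max_seq_len = 1034

# Shape of the context-server config after this change: the same limits
# appear both at the top level and under the new 'build_config' key.
ctx_config = {
    'build_config': {
        'max_batch_size': ctx_batch_size,
        'max_num_tokens': ctx_max_num_tokens,
        'max_seq_len': ctx_max_seq_len,
    },
    'max_batch_size': ctx_batch_size,
    'max_num_tokens': ctx_max_num_tokens,
    'max_seq_len': ctx_max_seq_len,
}

assert ctx_config['build_config']['max_seq_len'] == ctx_config['max_seq_len']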

examples/disaggregated/slurm/benchmark/start_worker.sh

Lines changed: 2 additions & 2 deletions

@@ -65,14 +65,14 @@ else
     nsys_file=${nsys_folder}/nsys_worker_proc_${instance_id}_${SLURM_PROCID}
     export TLLM_PROFILE_RECORD_GC=1
     export TLLM_NVTX_DEBUG=1
-    if [ "${role}" = "GEN" ]; then
+    if [ "${role}" = "GEN" ] && [ "$SLURM_PROCID" = "0" ]; then
         export TLLM_PROFILE_START_STOP=200-250
         nsys_prefix="nsys profile -e \"NSYS_MPI_STORE_TEAMS_PER_RANK=1\" -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=none"
         echo "nsys_prefix: ${nsys_prefix}"
     elif [ "${role}" = "CTX" ]; then
         echo "nsys is not enabled on ctx_gpus"
     fi
-    trtllm-llmapi-launch ${numa_bind_cmd} ${nsys_prefix} \
+    ${nsys_prefix} trtllm-llmapi-launch ${numa_bind_cmd} \
         trtllm-serve ${model_path} \
         --host $(hostname) --port ${port} \
         --extra_llm_api_options ${config_file}
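
Two changes here: profiling is now limited to SLURM rank 0 of the GEN role, so a multi-rank job writes one nsys report instead of one per process, and ${nsys_prefix} now precedes trtllm-llmapi-launch so nsys wraps the whole launcher rather than sitting between it and trtllm-serve. A minimal Python sketch of the new gate (illustration only, not the script itself; nsys flags abridged):

import os

# Only the GEN role on SLURM rank 0 gets an nsys prefix; everyone else
# gets "" and runs unprofiled.
def nsys_prefix_for(role: str, nsys_file: str) -> str:
    if role == "GEN" and os.environ.get("SLURM_PROCID") == "0":
        return f"nsys profile -o {nsys_file} -f true -c cudaProfilerApi"
    return ""

os.environ["SLURM_PROCID"] = "0"
print(bool(nsys_prefix_for("GEN", "/tmp/report")))  # True
print(bool(nsys_prefix_for("CTX", "/tmp/report")))  # False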

tensorrt_llm/serve/scripts/benchmark_dataset.py

Lines changed: 9 additions & 8 deletions

@@ -494,11 +494,14 @@ def sample(
 
             # Filter out sequences that are too long or too short
             requests = []
-            for prompt, initial_prompt_len, cached_token_ids in zip(
-                    dataset, prompt_lengths, prompt_token_ids):
-                i = len(requests)
-                if i == num_requests:
-                    break
+            dataset_len = len(dataset)
+
+            for i in range(num_requests):
+                # Use modulo to cycle through the dataset when num_requests > dataset_len
+                dataset_idx = i % dataset_len
+                prompt = dataset[dataset_idx]
+                initial_prompt_len = prompt_lengths[dataset_idx]
+                cached_token_ids = prompt_token_ids[dataset_idx]
 
                 # Skip empty prompt
                 if initial_prompt_len == 0:
@@ -534,9 +537,6 @@ def sample(
                         prompt_len=total_input_len,
                         expected_output_len=int(output_lens[i]),
                     ))
-            assert len(requests) == num_requests, (
-                f"Only {len(requests)} requests sampled from sharegpt dataset, {num_requests} requests are needed"
-            )
         else:
             for i in range(num_requests):
                 inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
@@ -1131,6 +1131,7 @@ def sample(
         if parser_fn is None:
            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
 
+        sampled_requests = []
         for item in self.data:
             if len(prompts) >= num_requests:
                 break
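
The main change replaces the zip-based loop, which stopped after one pass over the dataset and needed a trailing assert to catch shortfalls, with modulo indexing that cycles through the dataset until num_requests prompts are produced. A self-contained sketch of the indexing scheme, with stand-in data:

# Stand-in data; the real code indexes dataset, prompt_lengths and
# prompt_token_ids with the same dataset_idx.
dataset = ["prompt-a", "prompt-b", "prompt-c"]
num_requests = 7

dataset_len = len(dataset)
sampled = [dataset[i % dataset_len] for i in range(num_requests)]
print(sampled)
# ['prompt-a', 'prompt-b', 'prompt-c', 'prompt-a', 'prompt-b', 'prompt-c', 'prompt-a']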
