Skip to content

Commit 413d9b8

Browse files
pcastonguay and FredricZ-2007
authored and committed
[https://nvbugs/5552889][fix] fix: Prevent empty batch when using attention DP with disagg (NVIDIA#8372)
Signed-off-by: Patrice Castonguay <[email protected]> Signed-off-by: Mike Iovine <[email protected]> Signed-off-by: FredricZ-2007 <[email protected]>
1 parent af165a7 commit 413d9b8

File tree

4 files changed

+145
-20
lines changed

4 files changed

+145
-20
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -821,12 +821,7 @@ def _executor_loop_pp(self):
821821
f'{len(scheduled_batch.generation_requests)} generation requests'
822822
)
823823

824-
if self.enable_attention_dp:
825-
tp_batch_sizes = self.dist.tp_allgather(
826-
scheduled_batch.batch_size)
827-
can_queue = 0 not in tp_batch_sizes
828-
else:
829-
can_queue = scheduled_batch.batch_size > 0
824+
can_queue = self._can_queue(scheduled_batch)
830825

831826
if not can_queue:
832827
self.micro_batches[microbatch_id] = None
@@ -1004,6 +999,16 @@ def wait_on_pp_send_handles(self, microbatch_id):
1004999
self.send_handles[microbatch_id].wait()
10051000
self.send_handles[microbatch_id] = None
10061001

1002+
def _can_queue(self, scheduled_batch):
1003+
1004+
if self.enable_attention_dp:
1005+
tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size)
1006+
can_queue = 0 not in tp_batch_sizes
1007+
else:
1008+
can_queue = scheduled_batch.batch_size > 0
1009+
1010+
return can_queue
1011+
10071012
def _prepare_and_schedule_batch(self):
10081013
new_requests = self._fetch_and_activate_new_requests()
10091014
if self.should_stop_processing:
@@ -1126,8 +1131,8 @@ def _executor_loop(self):
11261131

11271132
finished_requests = []
11281133

1129-
if scheduled_batch.batch_size > 0 or (
1130-
self.enable_attention_dp and self.dist.tp_size > 1):
1134+
can_queue = self._can_queue(scheduled_batch)
1135+
if can_queue:
11311136
if self.kv_cache_transceiver:
11321137
# For generation requests which have completed KV cache transfer
11331138
self._prepare_disagg_gen_transmission_complete(
@@ -1139,8 +1144,11 @@ def _executor_loop(self):
11391144

11401145
self._kv_connector_start_batch(scheduled_batch)
11411146

1142-
if scheduled_batch.batch_size > 0 or (
1143-
self.enable_attention_dp and self.dist.tp_size > 1):
1147+
# if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
1148+
if self.kv_connector_manager:
1149+
can_queue = self._can_queue(scheduled_batch)
1150+
1151+
if can_queue:
11441152
# init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers.
11451153
# init_disagg_gen_requests must be before engine forward, where the prev_seq_slot is updated.
11461154
if self.guided_decoder is not None:
@@ -1298,7 +1306,8 @@ def _executor_loop_overlap(self):
12981306

12991307
self._pause_requests(scheduled_batch.paused_requests)
13001308

1301-
if scheduled_batch.batch_size > 0:
1309+
can_queue = self._can_queue(scheduled_batch)
1310+
if can_queue:
13021311
if self.kv_cache_transceiver:
13031312
# For generation requests which have completed KV cache transfer
13041313
self._prepare_disagg_gen_transmission_complete(
@@ -1307,7 +1316,11 @@ def _executor_loop_overlap(self):
13071316

13081317
self._kv_connector_start_batch(scheduled_batch)
13091318

1310-
if scheduled_batch.batch_size > 0:
1319+
# if using a kv connector, we need to call can_queue again since scheduled_batch might have changed
1320+
if self.kv_connector_manager:
1321+
can_queue = self._can_queue(scheduled_batch)
1322+
1323+
if can_queue:
13111324

13121325
# The generation requests that are do not have batch_idx,
13131326
# needs to be in front of the batch due to the assumptions
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
hostname: localhost
2+
port: 8000
3+
model: DeepSeek-V3-Lite/bf16
4+
backend: "pytorch"
5+
context_servers:
6+
num_instances: 1
7+
build_config:
8+
max_batch_size: 10
9+
max_num_tokens: 512
10+
max_seq_len: 768
11+
max_batch_size: 10
12+
max_num_tokens: 512
13+
max_seq_len: 768
14+
tensor_parallel_size: 2
15+
moe_expert_parallel_size: 2
16+
enable_attention_dp: true
17+
pipeline_parallel_size: 1
18+
print_iter_log: true
19+
cuda_graph_config: null
20+
disable_overlap_scheduler: true
21+
kv_cache_config:
22+
enable_block_reuse: false
23+
free_gpu_memory_fraction: 0.05
24+
max_tokens: 512
25+
cache_transceiver_config:
26+
max_tokens_in_buffer: 8448
27+
backend: DEFAULT
28+
urls:
29+
- "localhost:8001"
30+
generation_servers:
31+
num_instances: 1
32+
build_config:
33+
max_batch_size: 1
34+
max_num_tokens: 2048
35+
max_seq_len: 2560
36+
tensor_parallel_size: 1
37+
moe_expert_parallel_size: 1
38+
enable_attention_dp: false
39+
enable_lm_head_tp_in_adp: false
40+
pipeline_parallel_size: 1
41+
max_batch_size: 1
42+
max_num_tokens: 2048
43+
max_seq_len: 2560
44+
cuda_graph_config:
45+
enable_padding: true
46+
batch_sizes:
47+
- 1
48+
print_iter_log: true
49+
kv_cache_config:
50+
enable_block_reuse: false
51+
free_gpu_memory_fraction: 0.7
52+
max_tokens: 2560
53+
moe_config:
54+
backend: CUTLASS
55+
cache_transceiver_config:
56+
max_tokens_in_buffer: 8448
57+
backend: DEFAULT
58+
stream_interval: 1
59+
num_postprocess_workers: 1
60+
urls:
61+
- "localhost:8002"

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,10 @@ def get_test_config(test_desc, example_dir, test_root):
261261
(4,
262262
f"{test_configs_root}/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_ctxpp2_gentp2.yaml"
263263
),
264+
"deepseek_v3_lite_bf16_empty_batch":
265+
(3,
266+
f"{test_configs_root}/disagg_config_deepseek_v3_lite_empty_batch.yaml"
267+
),
264268
}
265269

266270
if test_desc not in config_map:
@@ -1530,14 +1534,19 @@ def run_disaggregated_benchmark(example_dir,
15301534
benchmark_model_root,
15311535
shared_gpt_path,
15321536
env=None,
1533-
cwd=None):
1537+
cwd=None,
1538+
num_ranks=2,
1539+
random_input_len=16,
1540+
random_output_len=64,
1541+
num_prompts=100,
1542+
max_concurrency=32,
1543+
skip_warmup=False):
15341544
"""Run disaggregated test with given configuration."""
15351545
run_env = env.copy()
15361546
run_env["UCX_TLS"] = "^ib"
1537-
num_rank = 2
15381547
workers_cmd = [
15391548
'mpirun', '--allow-run-as-root', '--oversubscribe', '-n',
1540-
str(num_rank), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
1549+
str(num_ranks), 'trtllm-serve', 'disaggregated_mpi_worker', '-c',
15411550
config_file
15421551
]
15431552

@@ -1589,15 +1598,15 @@ def run_disaggregated_benchmark(example_dir,
15891598
'--dataset-path',
15901599
shared_gpt_path,
15911600
'--random-input-len',
1592-
'256',
1601+
str(random_input_len),
15931602
'--random-output-len',
1594-
'64',
1603+
str(random_output_len),
15951604
'--random-prefix-len',
15961605
'0',
15971606
'--num-prompts',
1598-
'320',
1607+
str(num_prompts),
15991608
'--max-concurrency',
1600-
'32',
1609+
str(max_concurrency),
16011610
'--host',
16021611
'localhost',
16031612
'--port',
@@ -1608,7 +1617,8 @@ def run_disaggregated_benchmark(example_dir,
16081617
'e2el,ttft',
16091618
]
16101619
# warm up
1611-
check_call(benchmark_cmd, env=env)
1620+
if not skip_warmup:
1621+
check_call(benchmark_cmd, env=env)
16121622
output = check_output(benchmark_cmd, env=env)
16131623
e2el_pattern = r"Median E2EL \(ms\):\s*(\d+\.?\d*)"
16141624
ttft_pattern = r"Median TTFT \(ms\):\s*(\d+\.?\d*)"
@@ -1718,3 +1728,43 @@ def test_disaggregated_benchmark_on_diff_backends(
17181728

17191729
assert ucx_e2el > 0 and nixl_e2el > 0 and nixl_e2el < 1.05 * ucx_e2el
17201730
assert ucx_ttft > 0 and nixl_ttft > 0 and nixl_ttft < 1.05 * ucx_ttft
1731+
1732+
1733+
@pytest.mark.parametrize("benchmark_model_root", ['DeepSeek-V3-Lite-bf16'],
1734+
indirect=True)
1735+
def test_disaggregated_deepseek_v3_lite_bf16_empty_batch(
1736+
disaggregated_example_root, llm_venv, benchmark_model_root,
1737+
benchmark_root, shared_gpt_path):
1738+
1739+
src_dst_dict = {
1740+
benchmark_model_root:
1741+
f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/bf16",
1742+
}
1743+
for src, dst in src_dst_dict.items():
1744+
if not os.path.islink(dst):
1745+
os.makedirs(os.path.dirname(dst), exist_ok=True)
1746+
os.symlink(src, dst, target_is_directory=True)
1747+
1748+
test_desc = "deepseek_v3_lite_bf16_empty_batch"
1749+
num_ranks, config_file = get_test_config(test_desc,
1750+
disaggregated_example_root,
1751+
os.path.dirname(__file__))
1752+
1753+
env = llm_venv._new_env.copy()
1754+
e2el, ttft = run_disaggregated_benchmark(
1755+
disaggregated_example_root,
1756+
config_file,
1757+
benchmark_root,
1758+
benchmark_model_root,
1759+
shared_gpt_path,
1760+
env=env,
1761+
cwd=llm_venv.get_working_directory(),
1762+
num_ranks=num_ranks,
1763+
num_prompts=10,
1764+
max_concurrency=10,
1765+
random_input_len=384,
1766+
random_output_len=1536,
1767+
skip_warmup=True)
1768+
print(f"E2EL: {e2el} ms, TTFT: {ttft} ms")
1769+
1770+
assert e2el > 0 and ttft > 0

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ l0_dgx_h100:
158158
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_overlap_cuda_graph[DeepSeek-V3-Lite-fp8]
159159
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_cache_aware_balance[DeepSeek-V3-Lite-bf16]
160160
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_conditional[DeepSeek-V3-Lite-bf16]
161+
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_empty_batch[DeepSeek-V3-Lite-bf16]
161162
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
162163
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp[DeepSeek-V3-Lite-fp8]
163164
- disaggregated/test_workers.py::test_workers_conditional_disaggregation_deepseek_v3_lite_bf16[DeepSeek-V3-Lite-bf16]

0 commit comments

Comments
 (0)