
Commit 096dbeb

Ensure that last prefill chunk is handled correctly by Mamba models (#2897)
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
1 parent: ba456fd

10 files changed (+5712, -34 lines)

megatron/core/inference/contexts/attention_context/mamba_metadata.py
Lines changed: 3 additions & 5 deletions

@@ -154,7 +154,7 @@ def update(
             active_mamba_indices[:real_decode_count]
         )
         if padded_decode_count > real_decode_count:
-            self._batch_indices_decode_buffer[real_decode_count:padded_decode_count].fill_(-1)
+            self._batch_indices_decode_buffer[real_decode_count:padded_decode_count] = -1
         self.batch_indices_decode = self._batch_indices_decode_buffer[:padded_decode_count]

         # Determine if we have a chunked prefill request and adjust counts for regular prefill
@@ -180,9 +180,7 @@ def update(
         )

         if padded_prefill_count > regular_prefill_count:
-            self._batch_indices_prefill_buffer[
-                regular_prefill_count:padded_prefill_count
-            ].fill_(-1)
+            self._batch_indices_prefill_buffer[regular_prefill_count:padded_prefill_count] = -1

         self.batch_indices_prefill = self._batch_indices_prefill_buffer[:padded_prefill_count]

@@ -199,7 +197,7 @@ def update(
         )

         if padded_token_count > seq_len:
-            self._seq_idx_buffer[:, seq_len:padded_token_count].fill_(-1)
+            self._seq_idx_buffer[:, seq_len:padded_token_count] = -1
         self.seq_idx = self._seq_idx_buffer[:, :padded_token_count]

         # Update cu_seqlens
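Both idioms write the padding region of the preallocated buffer in place; the diff only swaps the Tensor.fill_ call for slice assignment. A minimal standalone sketch of the two forms (the buffer name and counts below are illustrative, not taken from the repo):

    import torch

    # Stand-ins for the real decode count and the padded count.
    real_decode_count, padded_decode_count = 3, 8
    batch_indices_decode_buffer = torch.arange(16)

    # Old idiom: in-place fill through a slice view.
    batch_indices_decode_buffer[real_decode_count:padded_decode_count].fill_(-1)

    # New idiom: slice assignment, which also writes in place.
    batch_indices_decode_buffer[real_decode_count:padded_decode_count] = -1

    # Either way, the padded tail holds -1 sentinels and the metadata exposes
    # only a view of the first padded_decode_count entries.
    batch_indices_decode = batch_indices_decode_buffer[:padded_decode_count]
    print(batch_indices_decode)  # tensor([ 0,  1,  2, -1, -1, -1, -1, -1])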

megatron/core/inference/contexts/dynamic_context.py
Lines changed: 11 additions & 11 deletions

@@ -540,6 +540,7 @@ def __init__(
         self.use_cuda_graphs_for_non_decode_steps = use_cuda_graphs_for_non_decode_steps
         # Deal with chunked prefill
         self.chunked_prefill_request_id = -1
+        self.has_explicit_chunked_prefill_req = False

         # FlashInfer.
         if use_flashinfer_fused_rope is True:
@@ -1300,15 +1301,11 @@ def initialize_attention_state(
         if construct_graph_dimensions is not None:
             self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions)

-        has_explicit_chunked_prefill_req = (
-            self.chunked_prefill_request_id != -1 and self.is_hybrid_model
-        )
-
         batch_dimensions = InferenceBatchDimensions(
             token_count=self.active_token_count,
             prefill_req_count=self.num_prefill_requests,
             decode_req_count=self.num_decode_requests,
-            has_explicit_chunked_prefill_req=has_explicit_chunked_prefill_req,
+            has_explicit_chunked_prefill_req=self.has_explicit_chunked_prefill_req,
         )
         self.batch_dimensions = batch_dimensions
         best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config(
@@ -1342,7 +1339,7 @@ def initialize_attention_state(
             token_count=padded_token_count,
             prefill_req_count=padded_prefill_req_count,
             decode_req_count=padded_decode_req_count,
-            has_explicit_chunked_prefill_req=has_explicit_chunked_prefill_req,
+            has_explicit_chunked_prefill_req=self.has_explicit_chunked_prefill_req,
         )
         self.padded_active_token_count = self.padded_batch_dimensions.token_count
         self.padded_active_request_count = self.padded_batch_dimensions.req_count
@@ -1373,6 +1370,8 @@ def initialize_attention_state(

         attn_dimensions = batch_dimensions
         if self.using_cuda_graph_this_step():
+            assert not self.has_explicit_chunked_prefill_req
+
             # Treat some decode requests as prefill requests to fit the cuda graph batch dimension.
             if batch_dimensions.decode_req_count > self.padded_batch_dimensions.decode_req_count:
                 total_req = batch_dimensions.req_count
@@ -1382,7 +1381,7 @@ def initialize_attention_state(
                     token_count=batch_dimensions.token_count,
                     prefill_req_count=adjusted_prefill_req_count,
                     decode_req_count=adjusted_decode_req_count,
-                    has_explicit_chunked_prefill_req=has_explicit_chunked_prefill_req,
+                    has_explicit_chunked_prefill_req=False,
                 )

         self.active_attn_metadata["mha_metadata"].update(
@@ -1461,6 +1460,7 @@ def reset(self) -> None:

         # Reset chunked prefill state
         self.chunked_prefill_request_id = -1
+        self.has_explicit_chunked_prefill_req = False
         self.num_prefill_requests = 0
         self._using_cuda_graph_this_step = False
         self.padded_batch_dimensions = InferenceBatchDimensions(
@@ -1981,6 +1981,7 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T
             active_requests_mask[-1] = (
                 1  # must keep this, next iteration will add a new chunk to it
             )
+            self.has_explicit_chunked_prefill_req = False

         active_request_count = (active_requests_mask == 1).sum().item()
         finished_request_count = (active_requests_mask == 0).sum().item()
@@ -2011,7 +2012,6 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T

             # Reset Mamba state.
             self.reset_mamba_state()
-
             return

         # 3. Concatenate the paused tokens to the active tokens if present.
@@ -2070,9 +2070,9 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T

         if self.chunked_prefill_request_id != -1:
             # find the id in request_ids that is the chunked_prefill_request_id. Only one request should be chunked.
-            active_requests_requiring_new_block[self.get_index_of_chunked_prefill_request()] = (
-                0  # chunked prefill should not be paused
-            )
+            active_requests_requiring_new_block[
+                self.get_index_of_chunked_prefill_request() - self.paused_request_count
+            ] = 0  # chunked prefill should not be paused

         active_requests_requiring_new_block_count = (
             (active_requests_requiring_new_block == 1).sum().item()
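The substantive fix in update_requests is the index shift: active_requests_requiring_new_block covers only the non-paused active requests, while get_index_of_chunked_prefill_request() appears to return a slot over the full request list, so the paused-request offset has to be subtracted before indexing. A toy illustration with made-up names and sizes (not the real data layout):

    import torch

    # Hypothetical layout: two paused requests followed by three active ones,
    # with the chunked-prefill request sitting at global slot 3.
    paused_request_count = 2
    chunked_prefill_global_index = 3

    # Mask covering only the active (non-paused) requests.
    active_requests_requiring_new_block = torch.tensor([1, 1, 1])

    # Indexing with the global slot would hit the wrong entry (or run past the
    # end); shifting by the paused count lands on the intended request.
    local_index = chunked_prefill_global_index - paused_request_count
    active_requests_requiring_new_block[local_index] = 0  # never pause the chunked request

    print(active_requests_requiring_new_block)  # tensor([1, 0, 1])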

megatron/core/inference/engines/dynamic_engine.py
Lines changed: 11 additions & 1 deletion

@@ -1081,6 +1081,12 @@ def schedule_chunked_prefill(self):

            if request_can_be_added and kv_cache_available:
                if token_fully_can_be_added:
+                    # For Mamba models we need to ensure that the last prefill chunk
+                    # is still tagged as a chunked prefill request.
+                    self.context.has_explicit_chunked_prefill_req = (
+                        self.context.is_hybrid_model
+                        and self.context.chunked_prefill_request_id == req.request_id
+                    )
                    self.context.chunked_prefill_request_id = -1
                    self.context.add_request(req)
                    self._loop.call_soon_threadsafe(
@@ -1091,14 +1097,18 @@
                    # Fully scheduled, so we remove from waiting pool
                    self.waiting_request_ids.popleft()
                    # Only this case we keep checking the rest of the waiting queue
-                    can_schedule = True
+                    # We break early for Mamba models running a final prefill chunk
+                    # so that no additional requests are scheduled beyond the chunked
+                    # prefill request.
+                    can_schedule = not self.context.has_explicit_chunked_prefill_req
                elif token_partially_can_be_added:
                    chunk_length = self.context.max_tokens - self.context.active_token_count
                    self.context.add_request(req, chunk_length=chunk_length)
                    self._loop.call_soon_threadsafe(
                        self._loop.create_task, self._notify_cond_for_new_request()
                    )
                    self.context.chunked_prefill_request_id = req.request_id
+                    self.context.has_explicit_chunked_prefill_req = self.context.is_hybrid_model
                    req.remaining_prompt_tokens = req.remaining_prompt_tokens[chunk_length:]
                    req.finished_chunk_token_count += chunk_length
                    # Still have tokens to prefill, so we break and keep the
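Taken together, the two hunks make the scheduler (a) tag the final chunk of a chunked-prefill request as an explicit chunked prefill when the model is a hybrid (Mamba) model, and (b) stop pulling more requests off the waiting queue for that step. A simplified sketch of that control flow, with hypothetical helpers (fits_entirely, fits_partially, remaining_token_budget) standing in for the real capacity checks:

    def schedule_waiting_requests(context, waiting_request_ids, requests):
        """Rough sketch of the chunked-prefill scheduling decisions; it does
        not mirror the real engine's loop or its asyncio notifications."""
        can_schedule = True
        while can_schedule and waiting_request_ids:
            req = requests[waiting_request_ids[0]]
            if req.fits_entirely(context):
                # Final chunk of a previously chunked request: for hybrid
                # (Mamba) models keep it tagged as an explicit chunked prefill.
                context.has_explicit_chunked_prefill_req = (
                    context.is_hybrid_model
                    and context.chunked_prefill_request_id == req.request_id
                )
                context.chunked_prefill_request_id = -1
                context.add_request(req)
                waiting_request_ids.popleft()
                # Do not schedule anything else alongside a final Mamba chunk.
                can_schedule = not context.has_explicit_chunked_prefill_req
            elif req.fits_partially(context):
                context.add_request(req, chunk_length=context.remaining_token_budget())
                context.chunked_prefill_request_id = req.request_id
                context.has_explicit_chunked_prefill_req = context.is_hybrid_model
                can_schedule = False  # this chunk used up the token budget
            else:
                can_schedule = False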

megatron/core/transformer/attention.py
Lines changed: 4 additions & 1 deletion

@@ -658,6 +658,7 @@ def flash_decode_and_prefill(
         cu_seqlens_k,
         seqlens_k,
         block_table,
+        is_decode_only,
     ) -> Tensor:
         """Flash attention kernel for mixed decode and prefill samples.

@@ -671,6 +672,7 @@ def flash_decode_and_prefill(
             cu_seqlens_k (Tensor): Cumulative key sequence lengths.
             seqlens_k (Tensor): key sequence lengths.
             block_table (Tensor): KV cache block ids for all samples.
+            is_decode_only (bool): True if batch is decode only.
         Return:
             (Tensor) Attention output.
         """
@@ -679,7 +681,7 @@
         assert block_table is not None

         # Flash attn kernel.
-        if max_seqlen_q > 1:
+        if not is_decode_only:
             q = q.squeeze(1)
             if getattr(self, "softmax_scale", None) is not None:
                 softmax_scale = self.softmax_scale
@@ -1065,6 +1067,7 @@ def forward(
                 cu_kv_lengths,
                 kv_lengths,
                 block_table,
+                inference_context.is_decode_only(),
             )
             core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)')
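The kernel-selection branch previously inferred "decode only" from max_seqlen_q > 1; the caller now passes inference_context.is_decode_only() explicitly, presumably so that a prefill chunk whose queries happen to be one token long is still routed down the prefill path. A toy dispatcher contrasting the two conditions (names are illustrative, not the real kernel selection):

    def choose_attention_path(is_decode_only: bool, max_seqlen_q: int) -> tuple:
        """Contrast the old length-based heuristic with the explicit flag."""
        # Old heuristic: any query longer than one token means "mixed batch".
        old_path = "varlen_prefill" if max_seqlen_q > 1 else "decode"
        # New behavior: trust the inference context's explicit flag.
        new_path = "decode" if is_decode_only else "varlen_prefill"
        return old_path, new_path

    # A batch that still contains a one-token prefill chunk:
    print(choose_attention_path(is_decode_only=False, max_seqlen_q=1))
    # ('decode', 'varlen_prefill')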

megatron/training/arguments.py
Lines changed: 6 additions & 2 deletions

@@ -1686,8 +1686,12 @@ def _add_inference_args(parser):
     group.add_argument('--mlp-chunks-for-prefill', type=int, default=1,
                        help='Number of chunks along sequence dimension for MLP '
                        'computation during prefill')
-    group.add_argument('--disable-chunked-prefill', default=False, action="store_true",
-                       help='Disable chunked prefill (chunked prefill is enabled by default).')
+    # TODO(ksanthanam): Clean this up in future PR
+    group.add_argument('--enable-chunked-prefill', dest='disable_chunked_prefill',
+                       action='store_false', default=True,
+                       help="Enable chunked prefill (disabled by default)")
+    group.add_argument('--disable-chunked-prefill', dest='disable_chunked_prefill',
+                       action='store_true', help=argparse.SUPPRESS)
     group.add_argument('--inference-dynamic-batching-cuda-graph-max-tokens',
                        type=int, default=16384,
                        help='Maximum number of tokens to capture in a cuda graph.')
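The new pair of flags shares a single destination: --enable-chunked-prefill stores False into disable_chunked_prefill, while the legacy --disable-chunked-prefill is kept for compatibility but hidden from --help via argparse.SUPPRESS. A standalone repro of that argparse pattern:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('inference')
    group.add_argument('--enable-chunked-prefill', dest='disable_chunked_prefill',
                       action='store_false', default=True,
                       help="Enable chunked prefill (disabled by default)")
    group.add_argument('--disable-chunked-prefill', dest='disable_chunked_prefill',
                       action='store_true', help=argparse.SUPPRESS)

    print(parser.parse_args([]).disable_chunked_prefill)                             # True
    print(parser.parse_args(['--enable-chunked-prefill']).disable_chunked_prefill)   # False
    print(parser.parse_args(['--disable-chunked-prefill']).disable_chunked_prefill)  # True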
