Skip to content

Commit 1fdb29f

Browse files
authored
Synchronize the request counts for EP inference with strict matching (#3033)
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
1 parent bc2eb9a commit 1fdb29f

File tree

6 files changed

+52
-29
lines changed

6 files changed

+52
-29
lines changed

megatron/core/inference/batch_dimensions_utils.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ def adjust_batch_dims_for_expert_parallelism(
183183
local_batch_dims.token_count,
184184
int(is_non_decode),
185185
int(has_explicit_chunked_prefill_req),
186+
local_batch_dims.prefill_req_count,
187+
local_batch_dims.decode_req_count,
186188
],
187189
dtype=torch.int32,
188190
device=torch.cuda.current_device(),
@@ -208,10 +210,21 @@ def adjust_batch_dims_for_expert_parallelism(
208210
return None # indicate no match, run in eager mode
209211

210212
assert not has_explicit_chunked_prefill_req
213+
214+
# If strict matching is enabled, we sync the request counts across EP ranks
215+
# to ensure the graph captures the maximum needed capacity.
216+
# TODO(ksanthanam): Add functional test for this scenario
217+
adjusted_prefill_req_count = (
218+
int(sync_tensor[3].item()) if strict else local_batch_dims.prefill_req_count
219+
)
220+
adjusted_decode_req_count = (
221+
int(sync_tensor[4].item()) if strict else local_batch_dims.decode_req_count
222+
)
223+
211224
adjusted_batch_dim = InferenceBatchDimensions(
212225
token_count=int(sync_tensor[0].item()),
213-
prefill_req_count=local_batch_dims.prefill_req_count,
214-
decode_req_count=local_batch_dims.decode_req_count,
226+
prefill_req_count=adjusted_prefill_req_count,
227+
decode_req_count=adjusted_decode_req_count,
215228
has_explicit_chunked_prefill_req=False,
216229
)
217230
return adjusted_batch_dim

megatron/core/inference/engines/dynamic_engine.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,7 +1706,11 @@ async def run_engine_with_coordinator(
17061706
if ep_group_has_work and local_pending_requests == 0:
17071707
# run dummy forward pass if EP group as a whole has work,
17081708
# but this rank does not have any work.
1709+
self.step_start_event.record()
17091710
self.controller.dummy_forward()
1711+
self.step_end_event.record()
1712+
self.step_end_event.synchronize()
1713+
self.step_count += 1
17101714
continue
17111715

17121716
# 3. No work in EP group

megatron/core/ssm/mamba_block.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,37 @@ def mamba_state_shapes_per_request(self) -> Optional[Tuple[Tuple[int], Tuple[int
202202
return layer.mamba_state_shapes_per_request()
203203
return None
204204

205+
def _should_call_local_cudagraph(self, *args, **kwargs):
206+
"""
207+
Check if we should call the local cudagraph path.
208+
"""
209+
if not self.training and (
210+
hasattr(self, 'cudagraph_manager')
211+
and kwargs['attention_mask'] is None
212+
and (
213+
kwargs.get('inference_context') is not None
214+
or kwargs.get('inference_params') is not None
215+
)
216+
and CudaGraphScope.full_iteration in self.config.cuda_graph_scope
217+
):
218+
if kwargs['inference_context'].is_static_batching():
219+
using_cuda_graph = kwargs['inference_context'].is_decode_only()
220+
else:
221+
using_cuda_graph = kwargs['inference_context'].using_cuda_graph_this_step()
222+
223+
if using_cuda_graph:
224+
return True
225+
return False
226+
227+
def __call__(self, *args, **kwargs):
228+
if self._should_call_local_cudagraph(*args, **kwargs):
229+
kwargs['hidden_states'] = (
230+
kwargs['hidden_states'].unwrap()
231+
if isinstance(kwargs['hidden_states'], WrappedTensor)
232+
else kwargs['hidden_states']
233+
)
234+
return super().__call__(*args, **kwargs)
235+
205236
def forward(
206237
self,
207238
hidden_states: Union[Tensor, WrappedTensor],

megatron/core/ssm/mamba_layer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def _should_call_local_cudagraph(self, *args, **kwargs):
192192
hasattr(self, 'cudagraph_manager')
193193
and kwargs.get('attention_mask') is None
194194
and kwargs.get('inference_context') is not None
195+
and CudaGraphScope.full_iteration not in self.config.cuda_graph_scope
195196
):
196197
using_cuda_graph = kwargs['inference_context'].using_cuda_graph_this_step()
197198
return using_cuda_graph

megatron/core/transformer/transformer_block.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -606,16 +606,7 @@ def __call__(self, *args, **kwargs):
606606
if isinstance(kwargs['hidden_states'], WrappedTensor)
607607
else kwargs['hidden_states']
608608
)
609-
# dynamic_inference_decode_only is not a real argument to forward, it is only used
610-
# to differentiate the cuda graph used for decode from the one used for non-decode
611-
# inference.
612-
dynamic_inference_decode_only = kwargs['inference_context'].is_decode_only()
613-
# cudagraphmanager returns a singleton tuple, whereas the
614-
# normal forward returns a tensor, therefore we need
615-
# to extract the tensor from the tuple
616-
return super().__call__(
617-
*args, dynamic_inference_decode_only=dynamic_inference_decode_only, **kwargs
618-
)[0]
609+
return super().__call__(*args, **kwargs)[0]
619610
return super().__call__(*args, **kwargs)
620611

621612
def forward(

megatron/core/transformer/transformer_layer.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -506,10 +506,6 @@ def forward(self, *args, **kwargs):
506506
This method calls the core computation of a transformer layer, including
507507
self-attention, cross-attention (if applicable), and feed-forward operations.
508508
"""
509-
# Remove 'dynamic_inference_decode_only' from kwargs if present
510-
# this is only used to uniquely identify decode and non-decode cuda graph
511-
# runners in the cuda graph manager
512-
kwargs.pop("dynamic_inference_decode_only", None)
513509
hidden_states, context = self._forward_attention(*args, **kwargs)
514510
output = self._forward_mlp(
515511
hidden_states,
@@ -1203,19 +1199,6 @@ def _should_call_local_cudagraph(self, *args, **kwargs):
12031199
return True
12041200
return False
12051201

1206-
def __call__(self, *args, **kwargs):
1207-
if self._should_call_local_cudagraph(*args, **kwargs):
1208-
# Inference mode.
1209-
if kwargs.get('inference_context') is not None:
1210-
# dynamic_inference_decode_only is not a real argument to forward, it is only used
1211-
# to differentiate the cuda graph used for decode from the one used for non-decode
1212-
# inference.
1213-
kwargs["dynamic_inference_decode_only"] = kwargs[
1214-
'inference_context'
1215-
].is_decode_only()
1216-
1217-
return super().__call__(*args, **kwargs)
1218-
12191202
def get_layer_norm_weights(self):
12201203
"""
12211204
Get the weights of all layernorms (attention and MLP) in the transformer layer.

0 commit comments

Comments (0)