@@ -183,6 +183,8 @@ def adjust_batch_dims_for_expert_parallelism(
183183 local_batch_dims.token_count,
184184 int(is_non_decode),
185185 int(has_explicit_chunked_prefill_req),
186+ local_batch_dims.prefill_req_count,
187+ local_batch_dims.decode_req_count,
186188 ],
187189 dtype=torch.int32,
188190 device=torch.cuda.current_device(),
@@ -208,10 +210,21 @@ def adjust_batch_dims_for_expert_parallelism(
208210 return None # indicate no match, run in eager mode
209211
210212 assert not has_explicit_chunked_prefill_req
213+
214+ # If strict matching is enabled, we sync the request counts across EP ranks
215+ # to ensure the graph captures the maximum needed capacity.
216+ # TODO(ksanthanam): Add functional test for this scenario
217+ adjusted_prefill_req_count = (
218+ int(sync_tensor[3].item()) if strict else local_batch_dims.prefill_req_count
219+ )
220+ adjusted_decode_req_count = (
221+ int(sync_tensor[4].item()) if strict else local_batch_dims.decode_req_count
222+ )
223+
211224 adjusted_batch_dim = InferenceBatchDimensions(
212225 token_count=int(sync_tensor[0].item()),
213- prefill_req_count=local_batch_dims.prefill_req_count,
214- decode_req_count=local_batch_dims.decode_req_count,
226 prefill_req_count=adjusted_prefill_req_count,
227 decode_req_count=adjusted_decode_req_count,
215228 has_explicit_chunked_prefill_req=False,
216229 )
217230 return adjusted_batch_dim
0 commit comments