@@ -593,8 +593,8 @@ def _ubatch_split(
         if not self.parallel_config.enable_microbatching:
             return (None, 0, None)
 
+        # Check preconditions for microbatching
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        num_reqs = self.input_batch.num_reqs
         should_attempt_ubatching = \
             self.parallel_config.enable_microbatching and \
             total_num_scheduled_tokens >= \
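For orientation, the precondition introduced in this hunk boils down to a token-count gate. Below is a minimal, illustrative sketch only: the threshold attribute name is borrowed from the capture hunk further down, and the right-hand side of the truncated `>=` comparison is an assumption.

```python
def should_attempt_ubatching(parallel_config, total_num_scheduled_tokens: int) -> bool:
    """Illustrative only: attempt dual-batch overlap (DBO) when microbatching is
    enabled and the batch is large enough to be worth splitting into two ubatches."""
    if not parallel_config.enable_microbatching:
        return False
    # Assumed comparison target; the diff hunk is cut off right after the `>=`.
    return (total_num_scheduled_tokens
            >= parallel_config.microbatching_token_threshold)
```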
@@ -610,11 +610,13 @@ def _ubatch_split(
         if not should_ubatch:
             return (None, 0, None)
 
-        # This doesn't actually pad the ubatch slices. It just initialize the
-        # split point to the correct value so that padding can be applied
+        # This doesn't actually pad the ubatch slices. It just initializes the
+        # split point to the padded value so that padding can be applied
         # to the second ubatch in pad_out_ubatch_slice after attention
         # metadata creation
-        assert num_pad_tokens < total_num_scheduled_tokens, f"num_pad_tokens {num_pad_tokens} original_num_tokens {total_num_scheduled_tokens}"
+        assert num_pad_tokens < total_num_scheduled_tokens, \
+            f"num_pad_tokens {num_pad_tokens} " \
+            f"original_num_tokens {total_num_scheduled_tokens}"
         total_num_tokens_per_ubatch = (total_num_scheduled_tokens +
                                        num_pad_tokens) // 2
         padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
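To make the split-point arithmetic concrete, here is a small worked example with made-up token counts; `second_ubatch_slice` is an illustrative name, not an identifier from this diff.

```python
# Worked example of the split-point math above, using assumed token counts.
total_num_scheduled_tokens = 21  # tokens actually scheduled this step
num_pad_tokens = 1               # padding that makes the two ubatches equal

total_num_tokens_per_ubatch = (total_num_scheduled_tokens +
                               num_pad_tokens) // 2                  # 11
padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)    # [0, 11)
second_ubatch_slice = slice(total_num_tokens_per_ubatch,
                            total_num_scheduled_tokens)              # [11, 21)
# The second ubatch holds only 10 real tokens here; pad_out_ubatch_slice later
# grows it to 11 once attention metadata has been created.
```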
@@ -2945,14 +2947,17 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
                     "decode" if uniform_decode else "mixed prefill-decode",
                     cudagraph_runtime_mode.name))
         enable_microbatching = self.parallel_config.enable_microbatching
-        # We skip EPLB here since we don't want to record dummy metrics
+        # DBO only supports running full cudagraphs with uniform
+        # decode lengths
         if enable_microbatching and uniform_decode:
             for num_tokens in compilation_cases:
                 # If the number of tokens is less than the microbatching
                 # threshold, don't generate a microbatched cudagraph
                 if (num_tokens
                         < self.parallel_config.microbatching_token_threshold):
                     continue
+
+                # Warmup
                 for _ in range(
                         self.compilation_config.cudagraph_num_of_warmups):
                     force_attention = (
@@ -2963,13 +2968,14 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
                                     uniform_decode=True,
                                     allow_microbatching=True,
                                     skip_eplb=True)
-                # DBO Only supports running with Full cudagraphs with uniform
-                # decode lengths
+
+                # Graph Capture
                 self._dummy_run(num_tokens,
                                 cudagraph_runtime_mode=CUDAGraphMode.FULL,
                                 uniform_decode=True,
                                 allow_microbatching=True,
                                 skip_eplb=True)
+        # We skip EPLB here since we don't want to record dummy metrics
         for num_tokens in compilation_cases:
             for _ in range(self.compilation_config.cudagraph_num_of_warmups):
                 # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
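Taken together, the last two hunks follow the usual warmup-then-capture recipe for the microbatched CUDA graphs: a few eager passes warm up lazy allocations, then one FULL-mode pass records the graph, with EPLB skipped so the dummy runs don't record metrics. Below is a hypothetical condensed view reusing identifiers from the diff; `runner` and `cases` are placeholder names, `force_attention` handling is omitted, and this is not the real method.

```python
# Hypothetical condensation of the microbatched-capture loop above.
# CUDAGraphMode is vLLM's cudagraph mode enum (import omitted in this sketch).
def capture_microbatched_graphs(runner, cases: list[int]) -> None:
    threshold = runner.parallel_config.microbatching_token_threshold
    for num_tokens in cases:
        if num_tokens < threshold:
            continue  # too small to ever be microbatched at runtime
        # Warmup: eager runs (no cudagraph) so lazy init happens outside capture.
        for _ in range(runner.compilation_config.cudagraph_num_of_warmups):
            runner._dummy_run(num_tokens,
                              uniform_decode=True,
                              allow_microbatching=True,
                              skip_eplb=True)
        # Graph capture: record a replayable FULL cudagraph for this decode size.
        runner._dummy_run(num_tokens,
                          cudagraph_runtime_mode=CUDAGraphMode.FULL,
                          uniform_decode=True,
                          allow_microbatching=True,
                          skip_eplb=True)
```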