@@ -1111,17 +1111,30 @@ def sync_and_slice_intermediate_tensors(
             for k, v in self.intermediate_tensors.items()
         })
 
-    def get_dp_padding(self, num_tokens: int):
+    def get_dp_padding(self,
+                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-        if dp_size == 1:
+
+        # For DP: Don't pad when setting enforce_eager.
+        # This lets us set enforce_eager on the prefiller in a P/D setup and
+        # still use CUDA graphs (enabled by this padding) on the decoder.
+        #
+        # TODO(tms): There are many cases where padding is enabled for
+        # prefills, causing unnecessary and excessive padding of activations.
+
+        if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
             # Early exit.
-            return 0
+            return 0, None
 
         num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
            num_tokens, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
-        return max_tokens_across_dp_cpu - num_tokens
+        num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
+                                                dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+        return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
 
     @torch.inference_mode()
     def execute_model(
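The new `get_dp_padding` pads every data-parallel rank up to the largest per-rank token count, so that all ranks launch identically shaped forward passes (which is what allows CUDA graph replay across the DP group), and it now also returns the padded per-rank counts as a CPU tensor. Below is a minimal, self-contained sketch of that logic; the cross-rank collective performed by `DPMetadata.num_tokens_across_dp` is replaced here by an explicit list of per-rank counts, and the names `dp_padding_sketch` and `tokens_per_rank` are illustrative, not part of this PR.

```python
# Sketch only: the real code obtains the per-rank counts via a collective.
from typing import Optional

import torch


def dp_padding_sketch(
        num_tokens: int,
        tokens_per_rank: list[int],
        enforce_eager: bool = False) -> tuple[int, Optional[torch.Tensor]]:
    dp_size = len(tokens_per_rank)
    if dp_size == 1 or enforce_eager:
        # No padding: single rank, or eager mode (no CUDA graphs to satisfy).
        return 0, None

    # Every rank pads up to the largest batch so all ranks run the same
    # shape; the returned tensor records the padded count for each rank.
    max_tokens = max(tokens_per_rank)
    num_tokens_after_padding = torch.tensor([max_tokens] * dp_size,
                                            device="cpu",
                                            dtype=torch.int32)
    return max_tokens - num_tokens, num_tokens_after_padding


# Example: this rank scheduled 5 tokens; the DP group scheduled [5, 8, 3].
pad, padded = dp_padding_sketch(5, [5, 8, 3])
assert pad == 3 and padded.tolist() == [8, 8, 8]
```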
@@ -1161,7 +1174,8 @@ def execute_model(
             num_input_tokens = num_scheduled_tokens
 
         # Padding for DP
-        num_input_tokens += self.get_dp_padding(num_input_tokens)
+        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
+        num_input_tokens += num_pad
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
@@ -1208,7 +1222,8 @@ def execute_model(
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
            self.maybe_setup_kv_connector(scheduler_output)
 
            model_output = self.model(
@@ -1681,7 +1696,8 @@ def _dummy_run(
     ) -> torch.Tensor:
 
         # Padding for DP
-        num_tokens += self.get_dp_padding(num_tokens)
+        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
+        num_tokens += num_pad
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1747,9 +1763,11 @@ def _dummy_run(
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_tokens, None, False)
 
-            with set_forward_context(attn_metadata,
-                                     self.vllm_config,
-                                     num_tokens=num_tokens):
+            with set_forward_context(
+                    attn_metadata,
+                    self.vllm_config,
+                    num_tokens=num_tokens,
+                    num_tokens_across_dp=num_tokens_across_dp):
                outputs = model(
                    input_ids=input_ids,
                    positions=positions,
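The remaining hunks only thread the second return value through to `set_forward_context`; this diff does not show how the forward context consumes it. As one purely illustrative use of a per-rank padded-count tensor (a hypothetical helper, not code from this PR), a prefix sum gives each DP rank its token-slice boundaries in a buffer concatenated across the group:

```python
import torch


def cu_tokens_across_dp(num_tokens_after_padding: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: prefix sums over the padded per-rank counts, so
    # rank r owns tokens [cu[r - 1], cu[r]) of a DP-concatenated buffer.
    return torch.cumsum(num_tokens_after_padding, dim=0)


# Using the [8, 8, 8] tensor from the sketch above:
cu = cu_tokens_across_dp(torch.tensor([8, 8, 8], dtype=torch.int32))
print(cu.tolist())  # [8, 16, 24]
```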