                            get_spec_metadata,
                            update_spec_config_from_model_config)
 from ..speculative.drafting_loops import BaseDraftingLoopWrapper
-from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
+from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
+                                  Eagle3ResourceManager, Eagle3SpecMetadata)
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.utils import SpecDecodingTensor
 from ..utils import (get_model_extra_attrs,
@@ -426,6 +427,7 @@ def __init__(
             mapping=self.mapping,
             dist=self.dist,
             kv_cache_manager_key=self.kv_cache_manager_key,
+            sparse_attention_config=self.sparse_attention_config,
         )
         self.cuda_graph_runner = CUDAGraphRunner(cuda_graph_runner_config)

@@ -568,13 +570,12 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         # Reset the global cuda graph dummy request to None in warmup.
         self.cuda_graph_runner.padding_dummy_request = None

-        cp_type = self.mapping.cp_config.get('cp_type', None)
-        if cp_type is not None:
-            if cp_type in [CpType.ULYSSES, CpType.STAR]:
-                logger.info(
-                    "[ModelEngine::warmup] Skipping warmup for cp_type: ",
-                    cp_type.name)
-                return
+        if self.mapping.cp_size > 1:
+            cp_type = self.mapping.cp_config.get("cp_type", None)
+            logger.info(
+                f"[ModelEngine::warmup] Skipping warmup for cp_type: {None if cp_type is None else cp_type.name}."
+            )
+            return

         self._run_torch_compile_warmup(resource_manager)
         self._run_autotuner_warmup(resource_manager)
@@ -625,7 +626,7 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager):
         """Runs a forward pass to populate the autotuner cache."""
         if not self.llm_args.enable_autotuner:
             return
-
+        AutoTuner.get().setup_distributed_state(self.mapping, self.dist)
         logger.info("Running autotuner warmup...")
         kv_cache_manager = resource_manager.get_resource_manager(
             self.kv_cache_manager_key)
@@ -635,8 +636,7 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager):
                                   self.batch_size * (self.max_seq_len - 1))

         cache_path = os.environ.get("TLLM_AUTOTUNER_CACHE_PATH", None)
-        with self.no_cuda_graph(), autotune(cache_path=cache_path,
-                                            rank=self.mapping.rank):
+        with self.no_cuda_graph(), autotune(cache_path=cache_path):
             warmup_request = self._create_warmup_request(
                 resource_manager, curr_max_num_tokens, 0)
             with self._release_batch_context(warmup_request,
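
For context, this warmup drives one padded dummy forward inside the `autotune()` scope so that kernel-tactic measurements happen once, up front, and are written to the cache path taken from `TLLM_AUTOTUNER_CACHE_PATH` when it is set. The sketch below illustrates only the general tune-once-then-reuse pattern; `autotune_scope`, `dispatch`, and the shape-keyed cache are made-up names, not the TensorRT-LLM `AutoTuner` API.

import contextlib
import time
from typing import Callable, Dict, List, Tuple

_tuning_enabled = False
_best_impl: Dict[Tuple[int, ...], int] = {}  # input shape -> index of fastest impl


@contextlib.contextmanager
def autotune_scope():
    """Enable measurement only inside the warmup scope (mirrors `with autotune(...)`)."""
    global _tuning_enabled
    prev, _tuning_enabled = _tuning_enabled, True
    try:
        yield
    finally:
        _tuning_enabled = prev


def dispatch(shape: Tuple[int, ...], impls: List[Callable[[], object]]):
    """Run the fastest implementation for `shape`, measuring once during warmup."""
    if shape not in _best_impl:
        if not _tuning_enabled:
            return impls[0]()  # outside warmup: fall back to the default implementation
        timings = []
        for impl in impls:
            start = time.perf_counter()
            impl()
            timings.append(time.perf_counter() - start)
        _best_impl[shape] = timings.index(min(timings))
    return impls[_best_impl[shape]]()


impls = [lambda: sum(range(100_000)), lambda: 100_000 * 99_999 // 2]

with autotune_scope():             # warmup: one max-size dummy pass fills the cache
    dispatch((8, 4096), impls)
print(dispatch((8, 4096), impls))  # serving: reuses the cached pick, no re-timing
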
@@ -704,31 +704,48 @@ def _capture_generation_cuda_graphs(self,
             draft_lengths.append(0)
             draft_lengths = [self.max_total_draft_tokens]

+        # Create CUDA graphs for short and long sequences separately for sparse attention.
+        sparse_config = self.sparse_attention_config
+        if sparse_config is not None and sparse_config.needs_separate_short_long_cuda_graphs(
+        ):
+            # For short sequences, use (seq_len_threshold - max_draft_len - 1) as the maximum sequence length
+            # to make sure all of the past and current input tokens are within the sequence length threshold.
+            # For long sequences, use the default maximum sequence length (self.max_seq_len).
+            max_seq_len = sparse_config.seq_len_threshold - (
+                self.max_draft_len + 1)
+            if max_seq_len < self.max_seq_len:
+                max_seq_len_list = [self.max_seq_len, max_seq_len]
+            else:
+                max_seq_len_list = [self.max_seq_len]
+        else:
+            max_seq_len_list = [self.max_seq_len]
+
         for bs in cuda_graph_batch_sizes:
             if bs > self.batch_size:
                 continue

             for draft_len in draft_lengths:
-                warmup_request = self._create_cuda_graph_warmup_request(
-                    resource_manager, bs, draft_len)
-                with self._release_batch_context(warmup_request,
-                                                 resource_manager) as batch:
-                    if batch is None:
-                        # No KV cache space, cannot continue capturing graphs
-                        return
-
-                    logger.info(
-                        f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}"
-                    )
-
-                    self.enable_spec_decode = draft_len > 0 or self.is_draft_model
-                    self._update_draft_inference_state_for_warmup(
-                        batch, draft_len > 0, resource_manager)
+                for max_seq_len in max_seq_len_list:
+                    warmup_request = self._create_cuda_graph_warmup_request(
+                        resource_manager, bs, draft_len, max_seq_len)
+                    with self._release_batch_context(warmup_request,
+                                                     resource_manager) as batch:
+                        if batch is None:
+                            # No KV cache space, cannot continue capturing graphs
+                            return
+
+                        logger.info(
+                            f"Run generation-only CUDA graph warmup for batch size={bs}, draft_len={draft_len}, max_seq_len={max_seq_len}"
+                        )
+
+                        self.enable_spec_decode = draft_len > 0 or self.is_draft_model
+                        self._update_draft_inference_state_for_warmup(
+                            batch, draft_len > 0, resource_manager)

-                    self.forward(batch,
-                                 new_tensors_device=None,
-                                 resource_manager=resource_manager)
-                    torch.cuda.synchronize()
+                        self.forward(batch,
+                                     new_tensors_device=None,
+                                     resource_manager=resource_manager)
+                        torch.cuda.synchronize()

     def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager):
         """Captures piecewise CUDA graphs for context/prefill steps via torch.compile."""
@@ -873,8 +890,11 @@ def _create_warmup_request(
         return result

     def _create_cuda_graph_warmup_request(
-            self, resource_manager: ResourceManager, batch_size: int,
-            draft_len: int) -> Optional[ScheduledRequests]:
+            self,
+            resource_manager: ResourceManager,
+            batch_size: int,
+            draft_len: int,
+            max_seq_len: int = None) -> Optional[ScheduledRequests]:
         """Creates a dummy ScheduledRequests tailored for CUDA graph capture."""
         kv_cache_manager = resource_manager.get_resource_manager(
             self.kv_cache_manager_key)
@@ -902,7 +922,8 @@ def _create_cuda_graph_warmup_request(
         available_tokens = kv_cache_manager.get_num_available_tokens(draft_len)

         # Add one dummy request with the maximum possible sequence length.
-        token_num = max(1, min(available_tokens, self.max_seq_len - 1))
+        max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len
+        token_num = max(1, min(available_tokens, max_seq_len - 1))
         model_config = self.model.model_config.pretrained_config
         max_position_embeddings = getattr(model_config,
                                           'max_position_embeddings', None)
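
The updated sizing above clamps the dummy warmup request to both the free KV-cache space and the (possibly shortened) maximum sequence length, with `max_seq_len=None` falling back to the engine-wide `self.max_seq_len` as before. A small standalone restatement with illustrative numbers:

from typing import Optional


def warmup_token_num(available_tokens: int, engine_max_seq_len: int,
                     max_seq_len: Optional[int] = None) -> int:
    """Length of the dummy CUDA-graph warmup request (always at least one token)."""
    limit = engine_max_seq_len if max_seq_len is None else max_seq_len
    return max(1, min(available_tokens, limit - 1))


print(warmup_token_num(100_000, 8192))        # 8191: bounded by max_seq_len - 1
print(warmup_token_num(100_000, 8192, 2044))  # 2043: short-graph variant
print(warmup_token_num(0, 8192))              # 1: no free KV space, still one token
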
@@ -1671,12 +1692,12 @@ def _prepare_tp_inputs(
             # Warmup doesn't have `total_input_len_cp` set because merge_helix_requests is not called.
             if not self.is_warmup and not request.is_cuda_graph_dummy:
                 position_id = request.total_input_len_cp + request.py_decoding_iter - 1
-                # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
-                if self.mapping.cp_rank == self.mapping.cp_size - 1:
-                    past_seen_token_num = request.orig_prompt_len + request.py_decoding_iter - 1
+                if request.py_helix_is_inactive_rank:
+                    past_seen_token_num = request.seqlen_this_rank_cp
                 else:
-                    # past_seen_token_num doesn't grow on inactive ranks.
-                    past_seen_token_num = request.orig_prompt_len
+                    # Discount the token added to the active rank in the resource manager,
+                    # as it hasn't been previously seen.
+                    past_seen_token_num = request.seqlen_this_rank_cp - 1

                 position_ids.append(position_id)
                 num_cached_tokens_per_seq.append(past_seen_token_num)
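
The rewritten helix branch above derives the cached-token count from the per-rank sequence length instead of assuming the last CP rank is the active one: an inactive rank has already seen every token it holds, while the active rank's `seqlen_this_rank_cp` already includes the token the resource manager appended for the current step, so one token is discounted. A standalone sketch of that rule, with made-up values:

def past_seen_tokens(seqlen_this_rank_cp: int, is_inactive_rank: bool) -> int:
    """KV-cache tokens already populated on this helix rank for a generation step."""
    if is_inactive_rank:
        # Inactive ranks never receive the new token; everything held is already cached.
        return seqlen_this_rank_cp
    # The resource manager has already grown the active rank's length by the token
    # being generated now, which is not yet in the KV cache.
    return seqlen_this_rank_cp - 1


print(past_seen_tokens(10, is_inactive_rank=False))  # 9: active rank, new token discounted
print(past_seen_tokens(7, is_inactive_rank=True))    # 7: inactive rank, nothing new this step
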
@@ -2015,6 +2036,11 @@ def previous_seq_slots_device():

         attn_metadata.request_ids = request_ids
         attn_metadata.prompt_lens = prompt_lengths
+        if helix_is_inactive_rank is not None and len(
+                helix_is_inactive_rank) > 0:
+            helix_is_inactive_rank = torch.tensor(helix_is_inactive_rank,
+                                                  dtype=torch.bool,
+                                                  device='cuda')
         attn_metadata.helix_is_inactive_rank = helix_is_inactive_rank
         attn_metadata.num_contexts = len(scheduled_requests.context_requests)
         # Use num_chunked_ctx_requests to record the number of extend context requests,
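
The added guard above converts the per-request list of helix-inactive flags into a boolean CUDA tensor only when the list is non-empty, passing `None` or an empty list through unchanged for the non-helix path. A standalone sketch of the same conversion (it falls back to CPU here so it runs without a GPU):

from typing import List, Optional

import torch


def to_inactive_rank_tensor(flags: Optional[List[bool]], device: str):
    """Convert a non-empty flag list to a bool tensor; pass None/empty through."""
    if flags is not None and len(flags) > 0:
        return torch.tensor(flags, dtype=torch.bool, device=device)
    return flags


device = "cuda" if torch.cuda.is_available() else "cpu"
print(to_inactive_rank_tensor([True, False, True], device))  # bool tensor on `device`
print(to_inactive_rank_tensor([], device))                   # [] unchanged
print(to_inactive_rank_tensor(None, device))                 # None unchanged
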
@@ -2089,6 +2115,9 @@ def previous_seq_slots_device():
                                                  num_accepted_draft_tokens)]
             if isinstance(spec_metadata, Eagle3SpecMetadata):
                 spec_metadata.request_accepted_path = request_accepted_path
+            if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
+                spec_metadata.populate_sampling_params_for_one_model(
+                    scheduled_requests.all_requests())
             spec_metadata.prepare()
             inputs['spec_metadata'] = spec_metadata

@@ -2643,7 +2672,7 @@ def forward(self,
             # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
             is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
                 spec_resource_manager, self.is_draft_model, self.attn_backend,
-                self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
+                self.model_is_wrapped)
             attn_metadata.update_spec_dec_param(
                 batch_size=scheduled_requests.batch_size,
                 is_spec_decoding_enabled=is_spec_dec_mode,
@@ -2685,6 +2714,7 @@ def forward(self,
                 spec_metadata=spec_metadata,
                 draft_tokens_cuda=self.draft_tokens_cuda
                 if self.is_spec_decode else None,
+                new_tensors_device=new_tensors_device,
                 spec_resource_manager=spec_resource_manager,
             )
             can_run_graph = key is not None
@@ -2844,11 +2874,17 @@ def _init_userbuffers(self, hidden_size):
         # Disable UB for unsupported platforms
         if not ub.ub_supported():
            return False
-        use_nccl_symmetric = self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC"
-        ub.initialize_userbuffers_manager(
-            self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size,
-            self.mapping.rank, self.mapping.gpus_per_node,
-            hidden_size * self.max_num_tokens * 2, use_nccl_symmetric)
+        # NCCL_SYMMETRIC strategy no longer requires UserBuffer allocator initialization.
+        # It uses NCCLWindowAllocator from ncclUtils directly.
+        if self.llm_args.allreduce_strategy == "NCCL_SYMMETRIC":
+            # Skip UB initialization for NCCL_SYMMETRIC - it uses NCCLWindowAllocator directly.
+            return False
+        ub.initialize_userbuffers_manager(self.mapping.tp_size,
+                                          self.mapping.pp_size,
+                                          self.mapping.cp_size,
+                                          self.mapping.rank,
+                                          self.mapping.gpus_per_node,
+                                          hidden_size * self.max_num_tokens * 2)

         return True

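
The reworked `_init_userbuffers` gate above amounts to: skip when userbuffers are unsupported, skip for `NCCL_SYMMETRIC` (which now allocates through `NCCLWindowAllocator` instead), and otherwise size the manager from `hidden_size * max_num_tokens * 2` bytes. A simplified decision helper capturing that flow; the `ub_supported` argument and strategy values stand in for the real platform check, and the byte math mirrors the diff's `* 2` (two bytes per fp16/bf16 element):

from typing import Optional


def userbuffers_pool_bytes(ub_supported: bool, allreduce_strategy: str,
                           hidden_size: int, max_num_tokens: int) -> Optional[int]:
    """UB pool size in bytes to allocate, or None when UB initialization is skipped."""
    if not ub_supported:
        return None  # platform cannot use userbuffers at all
    if allreduce_strategy == "NCCL_SYMMETRIC":
        return None  # served by the NCCL window allocator, no UB manager needed
    return hidden_size * max_num_tokens * 2  # two bytes per fp16/bf16 element


print(userbuffers_pool_bytes(True, "AUTO", 4096, 8192))            # 67108864
print(userbuffers_pool_bytes(True, "NCCL_SYMMETRIC", 4096, 8192))  # None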