diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 7b5b2cf0b14..99ecac85e83 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -895,8 +895,6 @@ def _create_warmup_request(
             return None
 
         num_extra_decoding_steps = self._get_num_extra_decoding_steps()
-        if num_extra_decoding_steps > 0:
-            return None  # Disable autotuning for fused drafting loops for now.
 
         if num_gen_requests > self.batch_size:
             return None
@@ -909,7 +907,10 @@ def _create_warmup_request(
         ctx_requests = []
         gen_requests = []
 
-        max_seq_len = self.max_seq_len - 1
+        # For drafting loops, reduce max_seq_len to leave room for extra decoding steps
+        max_seq_len = self.max_seq_len - 1 - num_extra_decoding_steps
+        if max_seq_len < 1:
+            return None  # Not enough sequence length for drafting loop
 
         num_full_seqs = 0
         num_left_over_tokens = 0
@@ -954,7 +955,8 @@ def _create_warmup_request(
                 token_nums=ctx_token_nums,
                 is_gen=False,
                 max_num_draft_tokens=self.runtime_draft_len,
-                use_mrope=self.use_mrope)
+                use_mrope=self.use_mrope,
+                num_extra_decoding_steps=num_extra_decoding_steps)
 
             if spec_resource_manager is not None:
                 spec_resource_manager.add_dummy_requests(
@@ -1546,7 +1548,6 @@ def _prepare_incremental_update_metadata(
 
         return lora_params
 
-    @torch.compile(options={"max-autotune": True})
     def _update_draft_input_tensors(self,
                                     num_accepted_tokens_device: torch.Tensor,
                                     new_tokens_device: torch.Tensor,
@@ -1671,7 +1672,6 @@ def _apply_incremental_update_draft(
 
         return inputs, self.gather_ids_cuda[:num_generation_tokens]
 
-    @torch.compile(options={"max-autotune": True})
     def _update_target_input_tensors(
             self, num_accepted_tokens_device: torch.Tensor,
             new_tokens_device: torch.Tensor,
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 3af32ebe4bd..bf42281abd4 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1708,7 +1708,6 @@ def _executor_loop_overlap(self):
                 self.iter_counter += 1
 
     @nvtx_range("_accept_draft_tokens")
-    @torch.compile(options={"max-autotune": True})
     def _accept_draft_tokens(
             self, scheduled_batch: ScheduledRequests,
             target_outputs: SampleStateTensors,
diff --git a/tensorrt_llm/_torch/speculative/drafting_loops.py b/tensorrt_llm/_torch/speculative/drafting_loops.py
index 159cd9d528c..8c828c9864b 100644
--- a/tensorrt_llm/_torch/speculative/drafting_loops.py
+++ b/tensorrt_llm/_torch/speculative/drafting_loops.py
@@ -120,24 +120,27 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
         new_draft_tokens = [self.sample(logits)]
         draft_logits = [logits]
 
-        with save_metadata_state(attn_metadata, spec_metadata):
-            batch_size = attn_metadata.num_seqs
-
-            new_position_ids = self.prepare_for_generation(
-                attn_metadata, spec_metadata, position_ids)
-            for i in range(self.max_draft_len - 1):
-                logits = self.draft_model.forward(
-                    input_ids=new_draft_tokens[-1],
-                    position_ids=new_position_ids,
-                    attn_metadata=attn_metadata,
-                    spec_metadata=spec_metadata)
-                new_draft_tokens.append(self.sample(logits))
-                draft_logits.append(logits)
-                new_position_ids += 1
-                attn_metadata.kv_lens_cuda[:batch_size] += 1
-                if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
-                    spec_metadata.hidden_states_read_indices[:batch_size].copy_(
-                        spec_metadata.hidden_states_write_indices[:batch_size])
+        if self.max_draft_len > 1:
+            is_eagle3 = isinstance(spec_metadata, Eagle3SpecMetadata)
+            with save_metadata_state(attn_metadata, spec_metadata):
+                batch_size = attn_metadata.num_seqs
+
+                new_position_ids = self.prepare_for_generation(
+                    attn_metadata, spec_metadata, position_ids)
+                for i in range(self.max_draft_len - 1):
+                    logits = self.draft_model.forward(
+                        input_ids=new_draft_tokens[-1],
+                        position_ids=new_position_ids,
+                        attn_metadata=attn_metadata,
+                        spec_metadata=spec_metadata)
+                    new_draft_tokens.append(self.sample(logits))
+                    draft_logits.append(logits)
+                    new_position_ids += 1
+                    attn_metadata.kv_lens_cuda[:batch_size] += 1
+                    if i == 0 and is_eagle3:
+                        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+                            spec_metadata.
+                            hidden_states_write_indices[:batch_size])
 
         return {
             "new_draft_tokens": torch.stack(new_draft_tokens),
@@ -153,7 +156,6 @@ def sample(self, logits: torch.Tensor) -> torch.Tensor:
 
         return tokens
 
-    @torch.compile(options={'max-autotune': True})
     def prepare_for_generation(self, attn_metadata: AttentionMetadata,
                                spec_metadata: SpecMetadata,
                                position_ids: torch.Tensor) -> torch.Tensor:
diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py
index f00a1ac1d20..6c73426a688 100644
--- a/tensorrt_llm/_torch/speculative/model_drafter.py
+++ b/tensorrt_llm/_torch/speculative/model_drafter.py
@@ -576,22 +576,53 @@ def _update_draft_tokens_for_target_inputs(
         if target_inputs.next_draft_tokens is None:
             return
 
-        if draft_tensors is not None:
-            for req_idx, request in enumerate(draft_batch.all_requests()):
-                target_req = self.req_id_to_old_request[request.py_request_id]
-                if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
-                    # Skip prefill requests
-                    continue
-                # Get the index of the draft/target tokens in the device tensor
-                draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
-                target_idx = target_req.py_seq_slot
-                target_inputs.new_tokens[draft_position + 1:draft_position +
-                                         draft_length + 1, target_idx,
-                                         0] = draft_tensors[0:draft_length,
-                                                            draft_idx]
-                target_inputs.next_draft_tokens[
-                    target_idx, draft_position:draft_position +
-                    draft_length] = draft_tensors[0:draft_length, draft_idx]
+        draft_indices = []
+        target_indices = []
+        for req_idx, request in enumerate(draft_batch.all_requests()):
+            target_req = self.req_id_to_old_request[request.py_request_id]
+            if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                # Skip prefill requests
+                continue
+            # Get the index of the draft/target tokens in the device tensor
+            draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
+            target_idx = target_req.py_seq_slot
+            draft_indices.append(draft_idx)
+            target_indices.append(target_idx)
+
+        if len(draft_indices) == 0:
+            return
+
+        device = draft_tensors.device
+
+        # Create index tensors
+        draft_indices_tensor = torch.tensor(draft_indices,
+                                            dtype=torch.long,
+                                            pin_memory=True).to(
+                                                device, non_blocking=True)
+        target_indices_tensor = torch.tensor(target_indices,
+                                             dtype=torch.long,
+                                             pin_memory=True).to(
+                                                 device, non_blocking=True)
+
+        # Pre-slice draft tensors: [draft_length, batch_size]
+        draft_slice = draft_tensors[0:draft_length]
+
+        # Gather all source data at once using single index_select kernel
+        # Result shape: [draft_length, num_requests]
+        gathered = draft_slice.index_select(1, draft_indices_tensor).to(
+            torch.int32)
+
+        # Scatter to new_tokens using advanced indexing (single kernel)
+        # Shape: [draft_length, num_requests] -> [seq_len, batch_size, beam_width]
+        target_inputs.new_tokens[draft_position + 1:draft_position +
+                                 draft_length + 1, target_indices_tensor,
+                                 0] = gathered
+
+        # Scatter to next_draft_tokens using advanced indexing (single kernel)
+        # Shape: [num_requests, draft_length] -> [batch_size, max_draft_len]
+        target_inputs.next_draft_tokens[target_indices_tensor,
+                                        draft_position:draft_position +
+                                        draft_length] = gathered.t()
 
     def _setup_draft_batch_and_resources(
             self,
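
A note on the drafting_loops.py hunk: when `max_draft_len == 1` the old code still entered `save_metadata_state(...)` around a `range(0)` loop that never ran, paying the metadata snapshot/restore for nothing; the new guard skips that path entirely. Below is a minimal, runnable sketch of that control-flow change, assuming toy stand-ins: `save_metadata_state`, `sample`, and `draft` here are illustrative helpers, not the module's real API.

```python
# Sketch only: why the drafting loop is wrapped in `if max_draft_len > 1`.
# All names here are stand-ins, not the real drafting_loops.py API.
from contextlib import contextmanager

import torch


@contextmanager
def save_metadata_state():
    # Stand-in for snapshotting/restoring attention + spec metadata.
    print("metadata saved")
    try:
        yield
    finally:
        print("metadata restored")


def sample(logits: torch.Tensor) -> torch.Tensor:
    return logits.argmax(dim=-1)


def draft(max_draft_len: int, prefill_logits: torch.Tensor) -> torch.Tensor:
    draft_tokens = [sample(prefill_logits)]  # draft token 0 comes from the prefill pass
    if max_draft_len > 1:  # extra decode steps (and the save/restore) only when needed
        with save_metadata_state():
            for _ in range(max_draft_len - 1):
                # A real step would run the draft model and bump position_ids/kv_lens.
                draft_tokens.append(sample(torch.randn_like(prefill_logits)))
    return torch.stack(draft_tokens)


print(draft(1, torch.randn(2, 32)).shape)  # no save/restore printed
print(draft(4, torch.randn(2, 32)).shape)  # save/restore wraps the 3 extra steps
```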
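On the model_drafter.py hunk: the per-request Python loop of slice assignments is replaced by one `index_select` gather plus two advanced-indexing scatters, so the number of launched copy kernels no longer grows with batch size, and the index tensors are staged in pinned host memory and moved with `non_blocking=True` so the host-to-device copy does not synchronize. The sketch below reproduces that pattern on toy tensors; every name and shape (`draft_tokens`, `new_tokens`, `next_draft_tokens`, `draft_slots`, `target_slots`) is made up for illustration and is not the PR's actual data layout.

```python
# Sketch of the gather/scatter pattern from the model_drafter.py hunk, on toy data.
import torch

draft_length, num_draft_slots, num_target_slots = 3, 8, 8
device = "cuda" if torch.cuda.is_available() else "cpu"
use_pinned = device == "cuda"

# Draft tokens per slot: [draft_length, num_draft_slots]
draft_tokens = torch.randint(0, 32000, (draft_length, num_draft_slots), device=device)
# Target-side buffers: [seq_len, batch, beam] and [batch, max_draft_len]
new_tokens = torch.zeros(draft_length + 1, num_target_slots, 1,
                         dtype=torch.int32, device=device)
next_draft_tokens = torch.zeros(num_target_slots, draft_length,
                                dtype=torch.int32, device=device)

# Host-side index lists (one entry per in-flight generation request), staged in
# pinned memory so the host-to-device copy can be issued non-blocking.
draft_slots = torch.tensor([0, 2, 5], dtype=torch.long, pin_memory=use_pinned)
target_slots = torch.tensor([1, 4, 6], dtype=torch.long, pin_memory=use_pinned)
draft_slots = draft_slots.to(device, non_blocking=True)
target_slots = target_slots.to(device, non_blocking=True)

# One gather over the slot dimension: [draft_length, num_requests]
gathered = draft_tokens.index_select(1, draft_slots).to(torch.int32)

# Two scatters via advanced indexing replace the old per-request loop.
new_tokens[1:draft_length + 1, target_slots, 0] = gathered
next_draft_tokens[target_slots, :draft_length] = gathered.t()

print(new_tokens[:, 1, 0], next_draft_tokens[1])
```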