12 changes: 6 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -895,8 +895,6 @@ def _create_warmup_request(
             return None
 
         num_extra_decoding_steps = self._get_num_extra_decoding_steps()
-        if num_extra_decoding_steps > 0:
-            return None  # Disable autotuning for fused drafting loops for now.
 
         if num_gen_requests > self.batch_size:
             return None
@@ -909,7 +907,10 @@ def _create_warmup_request(
         ctx_requests = []
         gen_requests = []
 
-        max_seq_len = self.max_seq_len - 1
+        # For drafting loops, reduce max_seq_len to leave room for extra decoding steps
+        max_seq_len = self.max_seq_len - 1 - num_extra_decoding_steps
+        if max_seq_len < 1:
+            return None  # Not enough sequence length for drafting loop
         num_full_seqs = 0
         num_left_over_tokens = 0
@@ -954,7 +955,8 @@ def _create_warmup_request(
             token_nums=ctx_token_nums,
             is_gen=False,
             max_num_draft_tokens=self.runtime_draft_len,
-            use_mrope=self.use_mrope)
+            use_mrope=self.use_mrope,
+            num_extra_decoding_steps=num_extra_decoding_steps)
 
         if spec_resource_manager is not None:
             spec_resource_manager.add_dummy_requests(
@@ -1546,7 +1548,6 @@ def _prepare_incremental_update_metadata(
 
         return lora_params
 
-    @torch.compile(options={"max-autotune": True})
     def _update_draft_input_tensors(self,
                                     num_accepted_tokens_device: torch.Tensor,
                                     new_tokens_device: torch.Tensor,
@@ -1671,7 +1672,6 @@ def _apply_incremental_update_draft(
 
         return inputs, self.gather_ids_cuda[:num_generation_tokens]
 
-    @torch.compile(options={"max-autotune": True})
     def _update_target_input_tensors(
             self, num_accepted_tokens_device: torch.Tensor,
             new_tokens_device: torch.Tensor,
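The warmup change above boils down to a sequence-length budget. A minimal sketch of that arithmetic; the helper name `warmup_max_seq_len` is hypothetical, the real logic lives inline in `_create_warmup_request`:

```python
# Hypothetical helper mirroring the budget computed in _create_warmup_request.
def warmup_max_seq_len(model_max_seq_len: int, num_extra_decoding_steps: int):
    # One slot is reserved for the next sampled token; fused drafting loops
    # additionally consume one slot per extra decoding step.
    budget = model_max_seq_len - 1 - num_extra_decoding_steps
    # A budget below 1 means no warmup request fits, matching the diff's
    # early `return None`.
    return budget if budget >= 1 else None

assert warmup_max_seq_len(4096, 3) == 4092  # normal case
assert warmup_max_seq_len(4, 3) is None     # drafting loop cannot fit
```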
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1708,7 +1708,6 @@ def _executor_loop_overlap(self):
             self.iter_counter += 1
 
     @nvtx_range("_accept_draft_tokens")
-    @torch.compile(options={"max-autotune": True})
     def _accept_draft_tokens(
             self, scheduled_batch: ScheduledRequests,
             target_outputs: SampleStateTensors,
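For context on what is being removed here and in the other files: the decorator is stock `torch.compile`, which trades one-time compile latency for faster steady-state kernels; dropping it runs the same code eagerly. A toy reproduction, where `bump_kv_lens` is a made-up stand-in for the executor's tensor-update methods:

```python
import torch

# Toy reproduction of the removed decorator; `bump_kv_lens` is a made-up
# stand-in, not a TensorRT-LLM function.
@torch.compile(options={"max-autotune": True})
def bump_kv_lens(kv_lens: torch.Tensor, batch_size: int) -> torch.Tensor:
    kv_lens[:batch_size] += 1  # identical semantics compiled or eager
    return kv_lens

kv_lens = torch.zeros(8, dtype=torch.int64)
bump_kv_lens(kv_lens, 4)  # first call pays the compilation cost
print(kv_lens.tolist())   # [1, 1, 1, 1, 0, 0, 0, 0]
```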
40 changes: 21 additions & 19 deletions tensorrt_llm/_torch/speculative/drafting_loops.py
@@ -120,24 +120,27 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
 
         new_draft_tokens = [self.sample(logits)]
         draft_logits = [logits]
-        with save_metadata_state(attn_metadata, spec_metadata):
-            batch_size = attn_metadata.num_seqs
-
-            new_position_ids = self.prepare_for_generation(
-                attn_metadata, spec_metadata, position_ids)
-            for i in range(self.max_draft_len - 1):
-                logits = self.draft_model.forward(
-                    input_ids=new_draft_tokens[-1],
-                    position_ids=new_position_ids,
-                    attn_metadata=attn_metadata,
-                    spec_metadata=spec_metadata)
-                new_draft_tokens.append(self.sample(logits))
-                draft_logits.append(logits)
-                new_position_ids += 1
-                attn_metadata.kv_lens_cuda[:batch_size] += 1
-                if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
-                    spec_metadata.hidden_states_read_indices[:batch_size].copy_(
-                        spec_metadata.hidden_states_write_indices[:batch_size])
+        if self.max_draft_len > 1:
+            is_eagle3 = isinstance(spec_metadata, Eagle3SpecMetadata)
+            with save_metadata_state(attn_metadata, spec_metadata):
+                batch_size = attn_metadata.num_seqs
+
+                new_position_ids = self.prepare_for_generation(
+                    attn_metadata, spec_metadata, position_ids)
+                for i in range(self.max_draft_len - 1):
+                    logits = self.draft_model.forward(
+                        input_ids=new_draft_tokens[-1],
+                        position_ids=new_position_ids,
+                        attn_metadata=attn_metadata,
+                        spec_metadata=spec_metadata)
+                    new_draft_tokens.append(self.sample(logits))
+                    draft_logits.append(logits)
+                    new_position_ids += 1
+                    attn_metadata.kv_lens_cuda[:batch_size] += 1
+                    if i == 0 and is_eagle3:
+                        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+                            spec_metadata.
+                            hidden_states_write_indices[:batch_size])
 
         return {
             "new_draft_tokens": torch.stack(new_draft_tokens),
@@ -153,7 +156,6 @@ def sample(self, logits: torch.Tensor) -> torch.Tensor:
 
         return tokens
 
-    @torch.compile(options={'max-autotune': True})
     def prepare_for_generation(self, attn_metadata: AttentionMetadata,
                                spec_metadata: SpecMetadata,
                                position_ids: torch.Tensor) -> torch.Tensor:
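The restructuring in `forward` is control flow only: the generation loop executes `max_draft_len - 1` iterations, so when `max_draft_len == 1` the first sampled token is already the entire draft and the metadata save/restore context can be skipped. A toy sketch of that invariant, with `sample_fn` standing in for the drafter's sample/forward step:

```python
# Toy sketch of the new guard: for max_draft_len == 1 the loop body would
# run zero times, so the whole block is skipped. `sample_fn` is a stand-in.
def run_drafting_loop(sample_fn, first_logits, max_draft_len: int):
    new_draft_tokens = [sample_fn(first_logits)]
    if max_draft_len > 1:  # guard added by this diff
        for _ in range(max_draft_len - 1):
            new_draft_tokens.append(sample_fn(first_logits))
    return new_draft_tokens

assert len(run_drafting_loop(lambda x: x + 1, 0, 1)) == 1
assert len(run_drafting_loop(lambda x: x + 1, 0, 4)) == 4
```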
63 changes: 47 additions & 16 deletions tensorrt_llm/_torch/speculative/model_drafter.py
@@ -576,22 +576,53 @@ def _update_draft_tokens_for_target_inputs(
         if target_inputs.next_draft_tokens is None:
             return
 
-        if draft_tensors is not None:
-            for req_idx, request in enumerate(draft_batch.all_requests()):
-                target_req = self.req_id_to_old_request[request.py_request_id]
-                if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
-                    # Skip prefill requests
-                    continue
-                # Get the index of the draft/target tokens in the device tensor
-                draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
-                target_idx = target_req.py_seq_slot
-                target_inputs.new_tokens[draft_position + 1:draft_position +
-                                         draft_length + 1, target_idx,
-                                         0] = draft_tensors[0:draft_length,
-                                                            draft_idx]
-                target_inputs.next_draft_tokens[
-                    target_idx, draft_position:draft_position +
-                    draft_length] = draft_tensors[0:draft_length, draft_idx]
+        draft_indices = []
+        target_indices = []
+        for req_idx, request in enumerate(draft_batch.all_requests()):
+            target_req = self.req_id_to_old_request[request.py_request_id]
+            if target_req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                # Skip prefill requests
+                continue
+            # Get the index of the draft/target tokens in the device tensor
+            draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
+            target_idx = target_req.py_seq_slot
+            draft_indices.append(draft_idx)
+            target_indices.append(target_idx)
+
+        if len(draft_indices) == 0:
+            return
+
+        device = draft_tensors.device
+
+        # Create index tensors
+        draft_indices_tensor = torch.tensor(draft_indices,
+                                            dtype=torch.long,
+                                            pin_memory=True).to(
+                                                device, non_blocking=True)
+        target_indices_tensor = torch.tensor(target_indices,
+                                             dtype=torch.long,
+                                             pin_memory=True).to(
+                                                 device, non_blocking=True)
+
+        # Pre-slice draft tensors: [draft_length, batch_size]
+        draft_slice = draft_tensors[0:draft_length]
+
+        # Gather all source data at once using single index_select kernel
+        # Result shape: [draft_length, num_requests]
+        gathered = draft_slice.index_select(1, draft_indices_tensor).to(
+            torch.int32)
+
+        # Scatter to new_tokens using advanced indexing (single kernel)
+        # Shape: [draft_length, num_requests] -> [seq_len, batch_size, beam_width]
+        target_inputs.new_tokens[draft_position + 1:draft_position +
+                                 draft_length + 1, target_indices_tensor,
+                                 0] = gathered
+
+        # Scatter to next_draft_tokens using advanced indexing (single kernel)
+        # Shape: [num_requests, draft_length] -> [batch_size, max_draft_len]
+        target_inputs.next_draft_tokens[target_indices_tensor,
+                                        draft_position:draft_position +
+                                        draft_length] = gathered.t()
 
     def _setup_draft_batch_and_resources(
             self,
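The rewrite above replaces a per-request Python loop (two small copies per request) with one `index_select` gather plus two advanced-indexing scatters. A self-contained CPU sketch of the same shapes, with made-up slot values; on GPU, the diff additionally stages the index lists through pinned host memory so the `.to(device, non_blocking=True)` copy can overlap with compute:

```python
import torch

# Self-contained CPU sketch; token values and slot indices are made up.
# Shapes follow the diff: draft_tensors is [draft_length, num_slots],
# new_tokens is [seq_len, batch_size, beam_width].
draft_length, num_slots, seq_len, batch_size, draft_position = 3, 8, 16, 8, 0
draft_tensors = torch.arange(draft_length * num_slots).reshape(
    draft_length, num_slots)
new_tokens = torch.zeros(seq_len, batch_size, 1, dtype=torch.int32)
next_draft_tokens = torch.zeros(batch_size, draft_length, dtype=torch.int32)

draft_idx = torch.tensor([2, 5], dtype=torch.long)   # draft seq slots
target_idx = torch.tensor([0, 3], dtype=torch.long)  # target seq slots

# One index_select gathers every active request's draft column at once.
gathered = draft_tensors[0:draft_length].index_select(1, draft_idx).to(
    torch.int32)

# Two advanced-indexing assignments scatter the result, one kernel each.
new_tokens[draft_position + 1:draft_position + draft_length + 1,
           target_idx, 0] = gathered
next_draft_tokens[target_idx,
                  draft_position:draft_position + draft_length] = gathered.t()
```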