Skip to content

Commit 370296e

Browse files
Committed: Remove torch.compile
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 237fd0e commit 370296e

File tree

3 files changed: +21 −22 lines changed

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,7 +1474,6 @@ def _prepare_incremental_update_metadata(
14741474

14751475
return lora_params
14761476

1477-
@torch.compile(options={"max-autotune": True})
14781477
def _update_draft_input_tensors(self,
14791478
num_accepted_tokens_device: torch.Tensor,
14801479
new_tokens_device: torch.Tensor,
@@ -1599,7 +1598,6 @@ def _apply_incremental_update_draft(
15991598

16001599
return inputs, self.gather_ids_cuda[:num_generation_tokens]
16011600

1602-
@torch.compile(options={"max-autotune": True})
16031601
def _update_target_input_tensors(
16041602
self, num_accepted_tokens_device: torch.Tensor,
16051603
new_tokens_device: torch.Tensor,

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1687,7 +1687,6 @@ def _executor_loop_overlap(self):
16871687
self.iter_counter += 1
16881688

16891689
@nvtx_range("_accept_draft_tokens")
1690-
@torch.compile(options={"max-autotune": True})
16911690
def _accept_draft_tokens(
16921691
self, scheduled_batch: ScheduledRequests,
16931692
target_outputs: SampleStateTensors,

tensorrt_llm/_torch/speculative/drafting_loops.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -120,24 +120,27 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
120120

121121
new_draft_tokens = [self.sample(logits)]
122122
draft_logits = [logits]
123-
with save_metadata_state(attn_metadata, spec_metadata):
124-
batch_size = attn_metadata.num_seqs
125-
126-
new_position_ids = self.prepare_for_generation(
127-
attn_metadata, spec_metadata, position_ids)
128-
for i in range(self.max_draft_len - 1):
129-
logits = self.draft_model.forward(
130-
input_ids=new_draft_tokens[-1],
131-
position_ids=new_position_ids,
132-
attn_metadata=attn_metadata,
133-
spec_metadata=spec_metadata)
134-
new_draft_tokens.append(self.sample(logits))
135-
draft_logits.append(logits)
136-
new_position_ids += 1
137-
attn_metadata.kv_lens_cuda[:batch_size] += 1
138-
if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
139-
spec_metadata.hidden_states_read_indices[:batch_size].copy_(
140-
spec_metadata.hidden_states_write_indices[:batch_size])
123+
if self.max_draft_len > 1:
124+
is_eagle3 = isinstance(spec_metadata, Eagle3SpecMetadata)
125+
with save_metadata_state(attn_metadata, spec_metadata):
126+
batch_size = attn_metadata.num_seqs
127+
128+
new_position_ids = self.prepare_for_generation(
129+
attn_metadata, spec_metadata, position_ids)
130+
for i in range(self.max_draft_len - 1):
131+
logits = self.draft_model.forward(
132+
input_ids=new_draft_tokens[-1],
133+
position_ids=new_position_ids,
134+
attn_metadata=attn_metadata,
135+
spec_metadata=spec_metadata)
136+
new_draft_tokens.append(self.sample(logits))
137+
draft_logits.append(logits)
138+
new_position_ids += 1
139+
attn_metadata.kv_lens_cuda[:batch_size] += 1
140+
if i == 0 and is_eagle3:
141+
spec_metadata.hidden_states_read_indices[:batch_size].copy_(
142+
spec_metadata.
143+
hidden_states_write_indices[:batch_size])
141144

142145
return {
143146
"new_draft_tokens": torch.stack(new_draft_tokens),
@@ -153,7 +156,6 @@ def sample(self, logits: torch.Tensor) -> torch.Tensor:
153156

154157
return tokens
155158

156-
@torch.compile(options={'max-autotune': True})
157159
def prepare_for_generation(self, attn_metadata: AttentionMetadata,
158160
spec_metadata: SpecMetadata,
159161
position_ids: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments (0)