Skip to content

Commit 26088df

Browse files
committed
[None][fix] Apply spec_decoding_position_offsets in Python RoPE path for EAGLE3 dynamic tree
Signed-off-by: qgai <qgai@nvidia.com>
1 parent 95f8bcf commit 26088df

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

tensorrt_llm/_torch/modules/attention.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,19 @@ def forward(
872872
else:
873873
q, k, v = qkv, None, None
874874

875+
# For dynamic tree spec decoding with Python RoPE, adjust position_ids
876+
# to use tree offsets (same as C++ kernel: past_seq_len + offset).
877+
if (not self.rope_fusion
878+
and getattr(attn_metadata, 'is_spec_dec_dynamic_tree', False)
879+
and getattr(attn_metadata, 'use_spec_decoding', False)
880+
and getattr(attn_metadata, 'spec_decoding_position_offsets',
881+
None) is not None
882+
and attn_metadata.spec_decoding_position_offsets.dim() ==
883+
1 # 1D layout ⇒ dynamic tree
884+
and position_ids is not None):
885+
position_ids = self._adjust_position_ids_for_spec_dec(
886+
position_ids, attn_metadata)
887+
875888
q, k, v = self.apply_rope(q, k, v, position_ids)
876889
q, k, v = self.convert_qkv(q, k, v)
877890

@@ -921,6 +934,25 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
921934
q, k = self.rotary_emb(position_ids, [q, k])
922935
return q, k, v
923936

937+
def _adjust_position_ids_for_spec_dec(self, position_ids, attn_metadata):
938+
"""Replicate C++ kernel's rotary_pos = past_seq_len + offset."""
939+
num_contexts = attn_metadata.num_contexts
940+
num_gens = attn_metadata.num_seqs - num_contexts
941+
if num_gens <= 0:
942+
return position_ids
943+
gen_len = int(attn_metadata.seq_lens[num_contexts])
944+
base_pos = attn_metadata.kv_lens_cuda[num_contexts:num_contexts +
945+
num_gens] - gen_len
946+
offsets = attn_metadata.spec_decoding_position_offsets[:num_gens *
947+
gen_len].view(
948+
num_gens,
949+
gen_len)
950+
adjusted = (base_pos.unsqueeze(1) + offsets).reshape(-1)
951+
start = attn_metadata.num_ctx_tokens
952+
end = start + num_gens * gen_len
953+
position_ids[0, start:end] = adjusted
954+
return position_ids
955+
924956
def apply_qk_norm(self, q, k):
925957
raise NotImplementedError(
926958
f"QK norm is not implemented for {self.__class__.__name__}. "

tensorrt_llm/_torch/speculative/eagle3_dynamic_tree.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,9 @@ def __init__(
182182
(max_batch_size, K + K * K * (max_draft_len - 1)), dtype=torch.float32, device="cuda"
183183
)
184184
self.history_draft_tokens_parent_buffer = torch.zeros(
185-
(max_batch_size, K * (max_draft_len - 1) + 1), dtype=torch.int64, device="cuda"
185+
(max_batch_size, max(K * (max_draft_len - 1) + 1, K + 1)),
186+
dtype=torch.int64,
187+
device="cuda",
186188
)
187189
self.tree_mask_buffer = torch.zeros(
188190
(max_batch_size * loop_max_tokens * loop_max_tokens),

0 commit comments

Comments
 (0)