@@ -872,6 +872,19 @@ def forward(
872872 else :
873873 q , k , v = qkv , None , None
874874
875+ # For dynamic tree spec decoding with Python RoPE, adjust position_ids
876+ # to use tree offsets (same as C++ kernel: past_seq_len + offset).
877+ if (not self .rope_fusion
878+ and getattr (attn_metadata , 'is_spec_dec_dynamic_tree' , False )
879+ and getattr (attn_metadata , 'use_spec_decoding' , False )
880+ and getattr (attn_metadata , 'spec_decoding_position_offsets' ,
881+ None ) is not None
882+ and attn_metadata .spec_decoding_position_offsets .dim () ==
883+ 1 # 1D layout ⇒ dynamic tree
884+ and position_ids is not None ):
885+ position_ids = self ._adjust_position_ids_for_spec_dec (
886+ position_ids , attn_metadata )
887+
875888 q , k , v = self .apply_rope (q , k , v , position_ids )
876889 q , k , v = self .convert_qkv (q , k , v )
877890
@@ -921,6 +934,25 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor],
921934 q , k = self .rotary_emb (position_ids , [q , k ])
922935 return q , k , v
923936
937+ def _adjust_position_ids_for_spec_dec (self , position_ids , attn_metadata ):
938+ """Replicate C++ kernel's rotary_pos = past_seq_len + offset."""
939+ num_contexts = attn_metadata .num_contexts
940+ num_gens = attn_metadata .num_seqs - num_contexts
941+ if num_gens <= 0 :
942+ return position_ids
943+ gen_len = int (attn_metadata .seq_lens [num_contexts ])
944+ base_pos = attn_metadata .kv_lens_cuda [num_contexts :num_contexts +
945+ num_gens ] - gen_len
946+ offsets = attn_metadata .spec_decoding_position_offsets [:num_gens *
947+ gen_len ].view (
948+ num_gens ,
949+ gen_len )
950+ adjusted = (base_pos .unsqueeze (1 ) + offsets ).reshape (- 1 )
951+ start = attn_metadata .num_ctx_tokens
952+ end = start + num_gens * gen_len
953+ position_ids [0 , start :end ] = adjusted
954+ return position_ids
955+
924956 def apply_qk_norm (self , q , k ):
925957 raise NotImplementedError (
926958 f"QK norm is not implemented for { self .__class__ .__name__ } . "
0 commit comments