@@ -120,24 +120,27 @@ def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
120120
121121 new_draft_tokens = [self .sample (logits )]
122122 draft_logits = [logits ]
123- with save_metadata_state (attn_metadata , spec_metadata ):
124- batch_size = attn_metadata .num_seqs
125-
126- new_position_ids = self .prepare_for_generation (
127- attn_metadata , spec_metadata , position_ids )
128- for i in range (self .max_draft_len - 1 ):
129- logits = self .draft_model .forward (
130- input_ids = new_draft_tokens [- 1 ],
131- position_ids = new_position_ids ,
132- attn_metadata = attn_metadata ,
133- spec_metadata = spec_metadata )
134- new_draft_tokens .append (self .sample (logits ))
135- draft_logits .append (logits )
136- new_position_ids += 1
137- attn_metadata .kv_lens_cuda [:batch_size ] += 1
138- if i == 0 and isinstance (spec_metadata , Eagle3SpecMetadata ):
139- spec_metadata .hidden_states_read_indices [:batch_size ].copy_ (
140- spec_metadata .hidden_states_write_indices [:batch_size ])
123+ if self .max_draft_len > 1 :
124+ is_eagle3 = isinstance (spec_metadata , Eagle3SpecMetadata )
125+ with save_metadata_state (attn_metadata , spec_metadata ):
126+ batch_size = attn_metadata .num_seqs
127+
128+ new_position_ids = self .prepare_for_generation (
129+ attn_metadata , spec_metadata , position_ids )
130+ for i in range (self .max_draft_len - 1 ):
131+ logits = self .draft_model .forward (
132+ input_ids = new_draft_tokens [- 1 ],
133+ position_ids = new_position_ids ,
134+ attn_metadata = attn_metadata ,
135+ spec_metadata = spec_metadata )
136+ new_draft_tokens .append (self .sample (logits ))
137+ draft_logits .append (logits )
138+ new_position_ids += 1
139+ attn_metadata .kv_lens_cuda [:batch_size ] += 1
140+ if i == 0 and is_eagle3 :
141+ spec_metadata .hidden_states_read_indices [:batch_size ].copy_ (
142+ spec_metadata .
143+ hidden_states_write_indices [:batch_size ])
141144
142145 return {
143146 "new_draft_tokens" : torch .stack (new_draft_tokens ),
@@ -153,7 +156,6 @@ def sample(self, logits: torch.Tensor) -> torch.Tensor:
153156
154157 return tokens
155158
156- @torch .compile (options = {'max-autotune' : True })
157159 def prepare_for_generation (self , attn_metadata : AttentionMetadata ,
158160 spec_metadata : SpecMetadata ,
159161 position_ids : torch .Tensor ) -> torch .Tensor :
0 commit comments