Skip to content

Commit dc27788

Browse files
committed
apply suggestion to cover eagle1 case
Signed-off-by: Ye Yu <[email protected]>
1 parent 529a2f2 commit dc27788

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

modelopt/torch/speculative/plugins/megatron_eagle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ def modify(
844844
self.kld = logits_kld_loss
845845

846846
def _get_eagle_input_hidden_states(self, hidden_states: torch.Tensor, apply_fc: bool = True):
847-
"""When _aux_hidden_states is not empty, then this is EAGLE-3.
847+
"""When _aux_hidden_states is not empty for online, then this is EAGLE-3.
848848
849849
Args:
850850
hidden_states: last hidden_states
@@ -1234,7 +1234,7 @@ def forward(
12341234

12351235
if self.eagle_offline:
12361236
eagle_module_input_hidden_states = self._get_eagle_input_hidden_states(
1237-
aux_hidden_states, apply_fc=True
1237+
aux_hidden_states, apply_fc=self.eagle_config.use_aux_hidden_state
12381238
)
12391239
# If EAGLE-3, aux_hidden_states are gathered by the forward_hook
12401240
elif return_eagle_inputs:

0 commit comments

Comments
 (0)