File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
modelopt/torch/speculative/plugins Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -844,7 +844,7 @@ def modify(
844
844
self .kld = logits_kld_loss
845
845
846
846
def _get_eagle_input_hidden_states (self , hidden_states : torch .Tensor , apply_fc : bool = True ):
847
- """When _aux_hidden_states is not empty, then this is EAGLE-3.
847
+ """When _aux_hidden_states is not empty for online , then this is EAGLE-3.
848
848
849
849
Args:
850
850
hidden_states: last hidden_states
@@ -1234,7 +1234,7 @@ def forward(
1234
1234
1235
1235
if self .eagle_offline :
1236
1236
eagle_module_input_hidden_states = self ._get_eagle_input_hidden_states (
1237
- aux_hidden_states , apply_fc = True
1237
+ aux_hidden_states , apply_fc = self . eagle_config . use_aux_hidden_state
1238
1238
)
1239
1239
# If EAGLE-3, aux_hidden_states are gathered by the forward_hook
1240
1240
elif return_eagle_inputs :
You can’t perform that action at this time.
0 commit comments