We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f8dc859 commit fa48315Copy full SHA for fa48315
modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -1284,9 +1284,10 @@ def forward(
1284
1285
# If eagle_freeze_base_model is set to True,
1286
# the base model is frozen .
1287
- loss = self.compute_language_model_loss(
1288
- labels, logits_sbh[:-1] if self.eagle_offline else logits_sbh
1289
- )
+ if self.eagle_offline:
+ loss = torch.zeros(input_ids.shape).to(input_ids.device)
+ else:
1290
+ loss = self.compute_language_model_loss(labels, logits_sbh)
1291
loss = 0.0 * loss
1292
1293
if self.eagle_config.parallel_draft_step > 1:
0 commit comments