Skip to content

Commit 501554b

Browse files
authored
add_mtp_loss_log (#3960)
1 parent fd47662 commit 501554b

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

paddleformers/trainer/trainer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2546,6 +2546,20 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
25462546
)
25472547
)
25482548
logs.update(self.global_training_logs)
2549+
2550+
# Add MTP loss metrics if available
2551+
try:
2552+
from paddlefleet.models.common.language_loss.language_loss import (
2553+
LanguageLoss,
2554+
)
2555+
2556+
if LanguageLoss.mtp_loss_tracker:
2557+
logs.update(
2558+
{k: v.item() if hasattr(v, "item") else v for k, v in LanguageLoss.mtp_loss_tracker.items()}
2559+
)
2560+
except (ImportError, AttributeError):
2561+
pass
2562+
25492563
self._total_loss_scalar += tr_loss_scalar
25502564
self._globalstep_last_logged = self.state.global_step
25512565
self._globalstep_last_start_time = time.time()

0 commit comments

Comments
 (0)