-
Notifications
You must be signed in to change notification settings - Fork 584
feat: add NaN detection during training #5135
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Fix deepmodeling#4985. This implementation is much simpler than deepmodeling#4986. Signed-off-by: Jinzhe Zeng <[email protected]>
for more information, see https://pre-commit.ci
📝 WalkthroughWalkthroughAdded an optional NaN check to Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~8 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 inconclusive)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
📜 Recent review detailsConfiguration used: Repository UI Review profile: CHILL Plan: Pro 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR adds NaN (Not a Number) detection during model training to prevent wasting time training models that have already diverged. When the total RMSE becomes NaN, training is immediately stopped with a descriptive error message. The implementation adds a new parameter check_total_rmse_nan to the format_training_message_per_task function with a default value of True.
- Adds NaN detection logic to check the "rmse" key in the RMSE dictionary
- Logs an error message and raises RuntimeError when NaN is detected
- Adds comprehensive docstring to document the new functionality
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @deepmd/loggers/training.py:
- Around line 56-62: The current NaN guard uses rmse.get("rmse", 0.0) which
silently defaults and skips detection when the "rmse" key is absent; change the
condition in the check_total_rmse_nan branch to explicitly test for the key and
then check NaN (e.g., only evaluate math.isnan on rmse["rmse"] when "rmse" in
rmse), i.e., replace the get(...) usage with an explicit membership check of the
rmse dict before calling math.isnan, keeping the existing log.error and
RuntimeError behavior for true NaN.
🧹 Nitpick comments (1)
deepmd/loggers/training.py (1)
28-42: Fix grammar in docstring.Line 41 has a grammatical error. It should read "Whether to throw an error" instead of "Whether throw the error".
📝 Proposed fix
- check_total_rmse_nan : bool - Whether throw the error if the total RMSE is NaN + check_total_rmse_nan : bool + Whether to throw an error if the total RMSE is NaN
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/loggers/training.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-12T13:40:14.334Z
Learnt from: CR
Repo: deepmodeling/deepmd-kit PR: 0
File: AGENTS.md:0-0
Timestamp: 2025-12-12T13:40:14.334Z
Learning: Verify PyTorch training output shows decreasing loss with 'batch X: trn: rmse' messages
Applied to files:
deepmd/loggers/training.py
🧬 Code graph analysis (1)
deepmd/loggers/training.py (3)
deepmd/pd/utils/stat.py (1)
rmse(536-537)deepmd/pt/utils/stat.py (1)
rmse(536-537)deepmd/utils/pair_tab.py (1)
get(207-209)
🔇 Additional comments (2)
deepmd/loggers/training.py (2)
3-6: LGTM!The import additions and logger initialization follow Python best practices. Using
logging.getLogger(__name__)ensures proper logger hierarchy.
51-63: Good implementation approach with clear error handling.The implementation correctly:
- Constructs the message with all RMSE values for visibility
- Logs the message before raising the error so users can see the problematic values
- Raises a descriptive RuntimeError that stops training
- Defaults to enabled (safer choice)
The sequence of operations is appropriate: format the message, check for NaN, log if found, then raise. This provides good visibility into what went wrong.
Co-authored-by: Copilot <[email protected]> Signed-off-by: Jinzhe Zeng <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5135 +/- ##
==========================================
- Coverage 82.15% 82.14% -0.01%
==========================================
Files 709 709
Lines 72470 72478 +8
Branches 3616 3615 -1
==========================================
+ Hits 59535 59540 +5
- Misses 11771 11775 +4
+ Partials 1164 1163 -1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Fix #4985.
This implementation is much simpler than #4986.
Summary by CodeRabbit
✏️ Tip: You can customize this high-level summary in your review settings.