feat: add NaN detection during training #4986
base: devel
```diff
@@ -75,6 +75,9 @@
 from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.nan_detector import (
+    check_total_loss_nan,
+)
 from deepmd.utils.path import (
     DPH5Path,
 )
```
```diff
@@ -859,6 +862,9 @@ def log_loss_valid(_task_key="Default"):
         if not self.multi_task:
             train_results = log_loss_train(loss, more_loss)
+            # Check for NaN in total loss using CPU values from lcurve computation
+            if self.rank == 0 and "rmse" in train_results:
+                check_total_loss_nan(display_step_id, train_results["rmse"])
             valid_results = log_loss_valid()
             if self.rank == 0:
                 log.info(
```
```diff
@@ -900,6 +906,11 @@ def log_loss_valid(_task_key="Default"):
                     loss, more_loss, _task_key=_key
                 )
                 valid_results[_key] = log_loss_valid(_task_key=_key)
+                # Check for NaN in total loss using CPU values from lcurve computation
+                if self.rank == 0 and "rmse" in train_results[_key]:
+                    check_total_loss_nan(
+                        display_step_id, train_results[_key]["rmse"]
+                    )
                 if self.rank == 0:
                     log.info(
                         format_training_message_per_task(
```
njzjz marked this conversation as resolved.
```diff
@@ -75,6 +75,9 @@
 from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.nan_detector import (
+    check_total_loss_nan,
+)

 if torch.__version__.startswith("2"):
     import torch._dynamo
```
```diff
@@ -949,6 +952,9 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
         if not self.multi_task:
             train_results = log_loss_train(loss, more_loss)
+            # Check for NaN in total loss using CPU values from lcurve computation
+            if self.rank == 0 and "rmse" in train_results:
+                check_total_loss_nan(display_step_id, train_results["rmse"])
```
Copilot AI commented on lines +956 to +957 (Sep 22, 2025):

The function is checking `"rmse"`, which represents the root-mean-square error, not the total loss. This could miss NaN in the actual total loss while falsely triggering on RMSE calculations. Consider using the actual total loss value instead of RMSE.

Suggested change:

```diff
-            if self.rank == 0 and "rmse" in train_results:
-                check_total_loss_nan(display_step_id, train_results["rmse"])
+            if self.rank == 0:
+                check_total_loss_nan(display_step_id, loss)
```
Suggested change:

```diff
-            if self.rank == 0 and "rmse" in train_results[_key]:
-                check_total_loss_nan(
-                    display_step_id, train_results[_key]["rmse"]
+            if self.rank == 0:
+                check_total_loss_nan(
+                    display_step_id, loss
```
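If the intent of the suggestion is to test the actual total loss rather than an RMSE metric, a minimal sketch of the idea (assuming `loss` is the scalar total-loss tensor already available in the PyTorch training loop; the helper name is hypothetical):

```python
import torch

from deepmd.utils.nan_detector import check_total_loss_nan


def check_step_loss(step: int, loss: torch.Tensor) -> None:
    # Move the scalar loss to the CPU and convert it to a Python float
    # before running the NaN check, so the detector never has to handle
    # device tensors.
    check_total_loss_nan(step, float(loss.detach().cpu().item()))
```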
```diff
@@ -60,6 +60,9 @@
 from deepmd.utils.data import (
     DataRequirementItem,
 )
+from deepmd.utils.nan_detector import (
+    check_total_loss_nan,
+)

 log = logging.getLogger(__name__)
```
```diff
@@ -684,6 +687,11 @@ def valid_on_the_fly(
         cur_batch = self.cur_batch
         current_lr = run_sess(self.sess, self.learning_rate)

+        # Check for NaN in total loss before writing to file and saving checkpoint
+        # We check the main total loss component that represents training loss
+        check_total_loss_nan(cur_batch, train_results["rmse"])
```
Suggested change:

```diff
-        check_total_loss_nan(cur_batch, train_results["rmse"])
+        check_total_loss_nan(cur_batch, train_results["loss"])
```
[P1] Guard against missing 'rmse' metric in TensorFlow NaN check

NaN detection in `valid_on_the_fly` calls `check_total_loss_nan(cur_batch, train_results["rmse"])` unconditionally. However, `get_evaluation_results` often produces metrics keyed as `rmse_e`, `rmse_f`, etc., and does not guarantee a `"rmse"` entry (the comment below mentions `rmse_*`). In those configurations, training now raises `KeyError: 'rmse'` before any logging or checkpointing, whereas the Paddle and PyTorch trainers already guard with `"rmse" in train_results`. TensorFlow should perform the same presence check or compute the appropriate scalar before invoking the NaN detector.
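A minimal sketch of the guard described above (the helper name is hypothetical, and the exact metric keys produced by `get_evaluation_results` may vary by configuration):

```python
from deepmd.utils.nan_detector import check_total_loss_nan


def check_train_results_nan(cur_batch: int, train_results: dict) -> None:
    # Prefer the aggregate "rmse" entry when it exists, mirroring the guard
    # used by the Paddle and PyTorch trainers; otherwise fall back to the
    # per-component rmse_* metrics so a missing key cannot raise KeyError.
    if "rmse" in train_results:
        check_total_loss_nan(cur_batch, train_results["rmse"])
        return
    for key, value in train_results.items():
        if key.startswith("rmse_"):
            check_total_loss_nan(cur_batch, value)
```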
@@ -0,0 +1,54 @@ (entire file added)

```python
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for detecting NaN values in loss during training."""

import logging
import math

log = logging.getLogger(__name__)


class LossNaNError(RuntimeError):
    """Exception raised when NaN is detected in total loss during training."""

    def __init__(self, step: int, total_loss: float) -> None:
        """Initialize the exception.

        Parameters
        ----------
        step : int
            The training step where NaN was detected
        total_loss : float
            The total loss value that contains NaN
        """
        self.step = step
        self.total_loss = total_loss
        message = (
            f"NaN detected in total loss at training step {step}: {total_loss}. "
            f"Training stopped to prevent wasting time with corrupted parameters. "
            f"This typically indicates unstable training conditions such as "
            f"learning rate too high, poor data quality, or numerical instability."
        )
        super().__init__(message)


def check_total_loss_nan(step: int, total_loss: float) -> None:
    """Check if the total loss contains NaN and raise an exception if found.

    This function is designed to be called during training after the total loss
    is computed and converted to a CPU float value.

    Parameters
    ----------
    step : int
        Current training step
    total_loss : float
        Total loss value to check for NaN

    Raises
    ------
    LossNaNError
        If the total loss contains NaN
    """
    if math.isnan(total_loss):
        log.error(f"NaN detected in total loss at step {step}: {total_loss}")
        raise LossNaNError(step, total_loss)
```
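For reference, a short usage sketch of the new utility (step numbers and loss values are illustrative):

```python
from deepmd.utils.nan_detector import LossNaNError, check_total_loss_nan

try:
    # A finite loss passes silently.
    check_total_loss_nan(step=100, total_loss=0.025)
    # A NaN loss logs an error and raises LossNaNError.
    check_total_loss_nan(step=101, total_loss=float("nan"))
except LossNaNError as err:
    print(f"training stopped at step {err.step}: {err}")
```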