Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.nan_detector import (
check_total_loss_nan,
)
from deepmd.utils.path import (
DPH5Path,
)
Expand Down Expand Up @@ -859,6 +862,9 @@ def log_loss_valid(_task_key="Default"):

if not self.multi_task:
train_results = log_loss_train(loss, more_loss)
# Check for NaN in total loss using CPU values from lcurve computation
if self.rank == 0 and "rmse" in train_results:
check_total_loss_nan(display_step_id, train_results["rmse"])
Copy link

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function is checking 'rmse' which represents root mean square error, not total loss. This could miss NaN in the actual total loss while falsely triggering on RMSE calculations. Consider using the actual total loss value instead of RMSE.

Suggested change
check_total_loss_nan(display_step_id, train_results["rmse"])
check_total_loss_nan(display_step_id, loss)

Copilot uses AI. Check for mistakes.
valid_results = log_loss_valid()
if self.rank == 0:
log.info(
Expand Down Expand Up @@ -900,6 +906,11 @@ def log_loss_valid(_task_key="Default"):
loss, more_loss, _task_key=_key
)
valid_results[_key] = log_loss_valid(_task_key=_key)
# Check for NaN in total loss using CPU values from lcurve computation
if self.rank == 0 and "rmse" in train_results[_key]:
check_total_loss_nan(
display_step_id, train_results[_key]["rmse"]
Copy link

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function is checking 'rmse' which represents root mean square error, not total loss. This could miss NaN in the actual total loss while falsely triggering on RMSE calculations. Consider using the actual total loss value instead of RMSE.

Copilot uses AI. Check for mistakes.
)
if self.rank == 0:
log.info(
format_training_message_per_task(
Expand Down
11 changes: 11 additions & 0 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.nan_detector import (
check_total_loss_nan,
)

if torch.__version__.startswith("2"):
import torch._dynamo
Expand Down Expand Up @@ -949,6 +952,9 @@ def log_loss_valid(_task_key: str = "Default") -> dict:

if not self.multi_task:
train_results = log_loss_train(loss, more_loss)
# Check for NaN in total loss using CPU values from lcurve computation
if self.rank == 0 and "rmse" in train_results:
check_total_loss_nan(display_step_id, train_results["rmse"])
Comment on lines +956 to +957
Copy link

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function is checking 'rmse' which represents root mean square error, not total loss. This could miss NaN in the actual total loss while falsely triggering on RMSE calculations. Consider using the actual total loss value instead of RMSE.

Suggested change
if self.rank == 0 and "rmse" in train_results:
check_total_loss_nan(display_step_id, train_results["rmse"])
if self.rank == 0:
check_total_loss_nan(display_step_id, loss)

Copilot uses AI. Check for mistakes.
valid_results = log_loss_valid()
if self.rank == 0:
log.info(
Expand Down Expand Up @@ -997,6 +1003,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
loss, more_loss, _task_key=_key
)
valid_results[_key] = log_loss_valid(_task_key=_key)
# Check for NaN in total loss using CPU values from lcurve computation
if self.rank == 0 and "rmse" in train_results[_key]:
check_total_loss_nan(
display_step_id, train_results[_key]["rmse"]
Comment on lines +1007 to +1009
Copy link

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function is checking 'rmse' which represents root mean square error, not total loss. This could miss NaN in the actual total loss while falsely triggering on RMSE calculations. Consider using the actual total loss value instead of RMSE.

Suggested change
if self.rank == 0 and "rmse" in train_results[_key]:
check_total_loss_nan(
display_step_id, train_results[_key]["rmse"]
if self.rank == 0:
check_total_loss_nan(
display_step_id, loss

Copilot uses AI. Check for mistakes.
)
if self.rank == 0:
log.info(
format_training_message_per_task(
Expand Down
8 changes: 8 additions & 0 deletions deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.utils.nan_detector import (
check_total_loss_nan,
)

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -684,6 +687,11 @@ def valid_on_the_fly(

cur_batch = self.cur_batch
current_lr = run_sess(self.sess, self.learning_rate)

# Check for NaN in total loss before writing to file and saving checkpoint
# We check the main total loss component that represents training loss
check_total_loss_nan(cur_batch, train_results["rmse"])
Copy link

Copilot AI Sep 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function is checking 'rmse' which represents root mean square error, not total loss. This could miss NaN in the actual total loss while falsely triggering on RMSE calculations. Consider using the actual total loss value instead of RMSE.

Suggested change
check_total_loss_nan(cur_batch, train_results["rmse"])
check_total_loss_nan(cur_batch, train_results["loss"])

Copilot uses AI. Check for mistakes.
Comment on lines +691 to +693

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] Guard against missing 'rmse' metric in TensorFlow NaN check

NaN detection in valid_on_the_fly calls check_total_loss_nan(cur_batch, train_results["rmse"]) unconditionally. However get_evaluation_results often produces metrics keyed as rmse_e, rmse_f, etc., and does not guarantee a "rmse" entry (the comment below mentions rmse_*). In those configurations training now raises KeyError: 'rmse' before any logging or checkpointing, whereas the Paddle and PyTorch trainers already guard with "rmse" in train_results. TensorFlow should perform the same presence check or compute the appropriate scalar before invoking the NaN detector.

Useful? React with 👍 / 👎.


if print_header:
self.print_header(fp, train_results, valid_results)
self.print_on_training(
Expand Down
54 changes: 54 additions & 0 deletions deepmd/utils/nan_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for detecting NaN values in loss during training."""

import logging
import math

log = logging.getLogger(__name__)


class LossNaNError(RuntimeError):
    """Raised when the total training loss becomes NaN.

    Carries the offending step and loss value so callers can log or
    report exactly where training diverged.
    """

    def __init__(self, step: int, total_loss: float) -> None:
        """Build the exception with context about where NaN appeared.

        Parameters
        ----------
        step : int
            The training step where NaN was detected
        total_loss : float
            The total loss value that contains NaN
        """
        # Keep the diagnostic context available as attributes for callers.
        self.step = step
        self.total_loss = total_loss
        # Assemble the user-facing message (identical wording to the
        # established error text for this condition).
        parts = [
            f"NaN detected in total loss at training step {step}: {total_loss}. ",
            "Training stopped to prevent wasting time with corrupted parameters. ",
            "This typically indicates unstable training conditions such as ",
            "learning rate too high, poor data quality, or numerical instability.",
        ]
        super().__init__("".join(parts))


def check_total_loss_nan(step: int, total_loss: float) -> None:
    """Check if the total loss contains NaN and raise an exception if found.

    This function is designed to be called during training after the total loss
    is computed and converted to a CPU float value.

    Parameters
    ----------
    step : int
        Current training step
    total_loss : float
        Total loss value to check for NaN

    Raises
    ------
    LossNaNError
        If the total loss contains NaN
    """
    if math.isnan(total_loss):
        # Use lazy %-style logging args so formatting is skipped when the
        # logger is disabled (and to follow the logging module convention).
        log.error("NaN detected in total loss at step %s: %s", step, total_loss)
        raise LossNaNError(step, total_loss)
Loading
Loading