17 changes: 17 additions & 0 deletions deepmd/pd/train/training.py
@@ -75,6 +75,9 @@
from deepmd.utils.data import (
    DataRequirementItem,
)
from deepmd.utils.nan_detector import (
    check_loss_nan,
)
from deepmd.utils.path import (
    DPH5Path,
)
@@ -951,6 +954,20 @@ def log_loss_valid(_task_key="Default"):
                fout, display_step_id, cur_lr, train_results, valid_results
            )

            # Check for NaN in loss values before saving checkpoint
            # Loss values are already on CPU at this point for display/logging
            if self.rank == 0:
                if not self.multi_task:
                    check_loss_nan(display_step_id, train_results)
                    if valid_results:
                        check_loss_nan(display_step_id, valid_results)
                else:
                    for task_key in train_results:
                        if train_results[task_key]:
                            check_loss_nan(display_step_id, train_results[task_key])
                        if valid_results[task_key]:
                            check_loss_nan(display_step_id, valid_results[task_key])

            if (
                ((_step_id + 1) % self.save_freq == 0 and _step_id != self.start_step)
                or (_step_id + 1) == self.num_steps
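The same check is added to the PyTorch trainer below. As a standalone sketch of how the branching above behaves, assuming train_results is a flat {metric: float} dict in single-task mode and a {task_key: {metric: float}} dict in multi-task mode (the dict shapes and values here are illustrative, not taken from this PR):

from deepmd.utils.nan_detector import check_loss_nan

# Single-task: one flat dict of scalar metrics (hypothetical values).
check_loss_nan(100, {"rmse_e": 0.012, "rmse_f": 0.034})  # all finite, returns None

# Multi-task: one sub-dict per task, mirroring the loop above (hypothetical values).
multi_task_results = {
    "water": {"rmse_e": 0.011, "rmse_f": float("nan")},
    "metal": {"rmse_e": 0.020, "rmse_f": 0.041},
}
for task_key in multi_task_results:
    if multi_task_results[task_key]:
        check_loss_nan(100, multi_task_results[task_key])  # raises LossNaNError on "water"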
17 changes: 17 additions & 0 deletions deepmd/pt/train/training.py
@@ -75,6 +75,9 @@
from deepmd.utils.data import (
    DataRequirementItem,
)
from deepmd.utils.nan_detector import (
    check_loss_nan,
)

if torch.__version__.startswith("2"):
    import torch._dynamo
@@ -1070,6 +1073,20 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
                fout, display_step_id, cur_lr, train_results, valid_results
            )

            # Check for NaN in loss values before saving checkpoint
            # Loss values are already on CPU at this point for display/logging
            if self.rank == 0:
                if not self.multi_task:
                    check_loss_nan(display_step_id, train_results)
                    if valid_results:
                        check_loss_nan(display_step_id, valid_results)
                else:
                    for task_key in train_results:
                        if train_results[task_key]:
                            check_loss_nan(display_step_id, train_results[task_key])
                        if valid_results[task_key]:
                            check_loss_nan(display_step_id, valid_results[task_key])

            if (
                (
                    (display_step_id) % self.save_freq == 0
10 changes: 10 additions & 0 deletions deepmd/tf/train/trainer.py
@@ -60,6 +60,9 @@
from deepmd.utils.data import (
    DataRequirementItem,
)
from deepmd.utils.nan_detector import (
    check_loss_nan,
)

log = logging.getLogger(__name__)

@@ -684,6 +687,13 @@ def valid_on_the_fly(

        cur_batch = self.cur_batch
        current_lr = run_sess(self.sess, self.learning_rate)

        # Check for NaN in loss values before writing to file and saving checkpoint
        # Loss values are already on CPU at this point
        check_loss_nan(cur_batch, train_results)
        if valid_results is not None:
            check_loss_nan(cur_batch, valid_results)

        if print_header:
            self.print_header(fp, train_results, valid_results)
        self.print_on_training(
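With this, all three trainers raise instead of silently writing NaN losses to the learning curve, so code that drives a trainer can catch the new exception and fail cleanly. A minimal sketch, assuming a generic trainer object whose loop calls check_loss_nan; the run_training helper is hypothetical and not part of this PR:

import logging

from deepmd.utils.nan_detector import LossNaNError

log = logging.getLogger(__name__)


def run_training(trainer) -> bool:
    """Run training; return False if it aborted on a NaN loss."""
    try:
        trainer.run()  # hypothetical: any trainer whose loop calls check_loss_nan
    except LossNaNError as e:
        # e.step and e.loss_dict record where the NaN appeared
        log.error("Training aborted at step %d: %s", e.step, e.loss_dict)
        return False
    return True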
119 changes: 119 additions & 0 deletions deepmd/utils/nan_detector.py
@@ -0,0 +1,119 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for detecting NaN values in loss during training."""

import logging
import math
from typing import (
    Any,
)

import numpy as np

log = logging.getLogger(__name__)


class LossNaNError(Exception):
    """Exception raised when NaN is detected in loss during training."""

    def __init__(self, step: int, loss_dict: dict[str, Any]) -> None:
        """Initialize the exception.

        Parameters
        ----------
        step : int
            The training step where NaN was detected
        loss_dict : dict[str, Any]
            Dictionary containing the loss values where NaN was found
        """
        self.step = step
        self.loss_dict = loss_dict
        super().__init__(self._format_message())

    def _format_message(self) -> str:
        """Format the error message."""
        nan_losses = []
        for key, value in self.loss_dict.items():
            if self._is_nan(value):
                nan_losses.append(f"{key}={value}")

        message = (
            f"NaN detected in loss at training step {self.step}. "
            f"Training stopped to avoid wasting time on corrupted parameters. "
            f"NaN values found in: {', '.join(nan_losses)}. "
            f"This typically indicates unstable training conditions such as "
            f"a learning rate that is too high, poor data quality, or numerical instability."
        )
        return message

    @staticmethod
    def _is_nan(value: Any) -> bool:
        """Check if a value is NaN."""
        if value is None:
            return False
        try:
            # Handle various tensor types and Python scalars
            if isinstance(value, np.ndarray):
                # NumPy array; must be checked before the generic ``item``
                # branch, since multi-element arrays also expose ``item``
                # but raise ValueError when it is called
                return bool(np.isnan(value).any())
            elif hasattr(value, "item"):
                # PyTorch/TensorFlow/PaddlePaddle scalar tensor
                return math.isnan(value.item())
            elif isinstance(value, (int, float)):
                # Python scalar
                return math.isnan(value)
            else:
                # Try to convert to float and check
                return math.isnan(float(value))
        except (TypeError, ValueError):
            # If we can't convert to float, assume it's not NaN
            return False


def check_loss_nan(step: int, loss_dict: dict[str, Any]) -> None:
    """Check if any loss values contain NaN and raise an exception if found.

    This function is designed to be called during training after loss values
    are computed and available on CPU, typically during the logging/display phase.

    Parameters
    ----------
    step : int
        Current training step
    loss_dict : dict[str, Any]
        Dictionary containing loss values to check for NaN

    Raises
    ------
    LossNaNError
        If any loss value contains NaN
    """
    nan_found = False
    for key, value in loss_dict.items():
        if LossNaNError._is_nan(value):
            nan_found = True
            log.error(f"NaN detected in {key} at step {step}: {value}")

    if nan_found:
        raise LossNaNError(step, loss_dict)


def check_single_loss_nan(step: int, loss_name: str, loss_value: Any) -> None:
    """Check if a single loss value contains NaN and raise an exception if found.

    Parameters
    ----------
    step : int
        Current training step
    loss_name : str
        Name/identifier of the loss
    loss_value : Any
        Loss value to check for NaN

    Raises
    ------
    LossNaNError
        If the loss value contains NaN
    """
    if LossNaNError._is_nan(loss_value):
        log.error(f"NaN detected in {loss_name} at step {step}: {loss_value}")
        raise LossNaNError(step, {loss_name: loss_value})
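For reference, a quick usage sketch of the two public helpers with made-up values:

import math

from deepmd.utils.nan_detector import (
    LossNaNError,
    check_loss_nan,
    check_single_loss_nan,
)

check_loss_nan(10, {"rmse_e": 0.01, "rmse_f": 0.02})  # all finite, returns None

try:
    check_single_loss_nan(10, "rmse_f", math.nan)
except LossNaNError as e:
    print(e.step)       # 10
    print(e.loss_dict)  # {'rmse_f': nan}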