Merged (changes from 9 commits)
deepmd/dpmodel/utils/learning_rate.py (52 changes: 41 additions & 11 deletions)

@@ -3,15 +3,23 @@
     ABC,
     abstractmethod,
 )
+from types import ModuleType
+from typing import (
+    Any,
+    overload,
+    override,
+)
 
+import array_api_compat
 import numpy as np
 
 from deepmd.common import (
     j_get_type,
 )
+from deepmd.dpmodel.array_api import (
+    Array,
+)
 from deepmd.utils.plugin import (
     PluginVariant,
     make_plugin_registry,

@@ -44,10 +52,26 @@ def __init__(
         self.stop_steps = stop_steps
 
     @abstractmethod
-    def value(self, step: int) -> np.float64:
+    def value(self, step: int | Array) -> Array:
         """Get the learning rate at the given step."""
+        # in optax, step will be a jnp.ndarray passed in JIT mode
         pass
 
+    @overload
+    def array_namespace(self, step: int) -> ModuleType: ...
+    @overload
+    def array_namespace(self, step: Array) -> Any: ...
+
+    def array_namespace(self, step: int | Array) -> Any:
+        """Get the array API namespace based on the type of step.
+
+        If the step is int, use NumPy.
+        """
+        if array_api_compat.is_array_api_obj(step):
+            xp = array_api_compat.array_namespace(step)
+            return xp
+        return np
 
 
 @BaseLR.register("exp")
 class LearningRateExp(BaseLR):
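
This dispatch keeps plain int training steps on NumPy while framework arrays select their own namespace through array_api_compat. A minimal standalone sketch of the behavior (assuming only NumPy and array_api_compat are installed; the free function here just mirrors the method above):

import array_api_compat
import numpy as np

def array_namespace(step):
    # Arrays carry their own Array API namespace; Python ints fall back to NumPy.
    if array_api_compat.is_array_api_obj(step):
        return array_api_compat.array_namespace(step)
    return np

assert array_namespace(100) is np
xp = array_namespace(np.asarray(100))  # the array-api-compat NumPy wrapper
print(xp.pow(xp.asarray(0.99), 100 // 10))  # the same call works on any backend
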
@@ -88,16 +112,20 @@ def __init__(
             self.decay_steps = default_ds
         self.decay_rate = np.exp(
             np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps)
-        )
+        ).item()
         if decay_rate is not None:
             self.decay_rate = decay_rate
         self.min_lr = self.stop_lr
 
-    def value(self, step: int) -> np.float64:
+    def value(self, step: int | Array) -> Array:
         """Get the learning rate at the given step."""
-        step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps)
-        if step_lr < self.min_lr:
-            step_lr = self.min_lr
+        xp = self.array_namespace(step)
+        step_lr = self.start_lr * xp.pow(
+            xp.asarray(self.decay_rate), step // self.decay_steps
+        )
+        # the original implementation `if step_lr < self.min_lr:`
+        # will cause a dynamic graph which is unsupported in JAX JIT
+        step_lr = xp.clip(step_lr, self.min_lr, None)
         return step_lr
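
The comment in the hunk above is the portability constraint driving this change: under jax.jit, Python control flow on a traced value fails at trace time, so the lower bound is expressed with clip instead. A sketch of the failure mode, assuming JAX is installed (the function names are illustrative):

import jax
import jax.numpy as jnp

MIN_LR = 1e-8

@jax.jit
def lr_branchy(step_lr):
    # A Python `if` on a traced value cannot be evaluated under jit.
    if step_lr < MIN_LR:
        step_lr = MIN_LR
    return step_lr

@jax.jit
def lr_clipped(step_lr):
    # clip expresses the same lower bound as a traceable array op.
    return jnp.clip(step_lr, MIN_LR, None)

print(lr_clipped(jnp.asarray(1e-12)))  # 1e-08
# lr_branchy(jnp.asarray(1e-12))  # raises jax.errors.TracerBoolConversionError
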


@@ -128,12 +156,14 @@ def __init__(
         super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
         self.lr_min_factor = stop_lr / start_lr
 
-    def value(self, step: int) -> np.float64:
-        if step >= self.stop_steps:
-            return self.start_lr * self.lr_min_factor
-        return self.start_lr * (
+    def value(self, step: int | Array) -> Array:
+        xp = self.array_namespace(step)
+        min_lr = self.start_lr * self.lr_min_factor
+        step_lr = self.start_lr * (
             self.lr_min_factor
             + 0.5
             * (1 - self.lr_min_factor)
-            * (1 + np.cos(np.pi * (step / self.stop_steps)))
+            * (1 + xp.cos(xp.asarray(xp.pi * (step / self.stop_steps))))
         )
+        step_lr = xp.where(xp.asarray(step) >= self.stop_steps, min_lr, step_lr)
+        return step_lr
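
For reference, the decay_rate of the exponential schedule is chosen so the rate decays from start_lr to roughly stop_lr over stop_steps. A quick numeric check with the parameters used in the new test below (plain arithmetic on the formula from the hunk above, nothing beyond it):

import numpy as np

start_lr, stop_lr = 1e-3, 1e-8
decay_steps, stop_steps = 1000, 1000000

decay_rate = np.exp(np.log(stop_lr / start_lr) / (stop_steps / decay_steps)).item()
print(decay_rate)  # ~0.988553, applied once per 1000-step interval

# after stop_steps the schedule lands on stop_lr, which clip then holds:
print(start_lr * decay_rate ** (stop_steps // decay_steps))  # ~1e-08
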
pyproject.toml (3 changes: 2 additions & 1 deletion)

@@ -86,7 +86,8 @@ test = [
"pytest-sugar",
"pytest-split",
"dpgui",
'array-api-strict>=2,!=2.1.1;python_version>="3.9"',
# to support Array API 2024.12
'array-api-strict>=2.2;python_version>="3.9"',
]
docs = [
"sphinx>=3.1.1",
Expand Down

source/tests/array_api_strict/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -5,4 +5,4 @@
 
 # this is the default version in the latest array_api_strict,
 # but in old versions it may be 2022.12
-array_api_strict.set_array_api_strict_flags(api_version="2023.12")
+array_api_strict.set_array_api_strict_flags(api_version="2024.12")
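
The bump to the 2024.12 standard is presumably what lets the new value() implementations pass Python scalars directly to Array API functions (for example, min_lr as a plain float in xp.where). A minimal check against array_api_strict, assuming a version recent enough to implement the 2024.12 scalar-argument rules:

import array_api_strict as xp

xp.set_array_api_strict_flags(api_version="2024.12")

cond = xp.asarray([True, False])
# Under 2024.12, one branch of where() may be a Python scalar:
print(xp.where(cond, 0.5, xp.asarray([1.0, 2.0])))
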
source/tests/consistent/test_learning_rate.py (new file: 80 additions & 0 deletions)

@@ -0,0 +1,80 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import sys
import unittest
from typing import (
    Any,
)

import numpy as np

from deepmd.dpmodel.common import (
    to_numpy_array,
)
from deepmd.dpmodel.utils.learning_rate import (
    BaseLR,
)

from .common import (
    INSTALLED_ARRAY_API_STRICT,
    INSTALLED_JAX,
    INSTALLED_PT,
    parameterized,
)

if INSTALLED_PT:
    import array_api_compat.torch as torch_xp
    import torch
if INSTALLED_JAX:
    from deepmd.jax.env import (
        jnp,
    )
if INSTALLED_ARRAY_API_STRICT:
    import array_api_strict as xp


@parameterized(
    (
        {
            "type": "exp",
            "start_lr": 1e-3,
            "stop_lr": 1e-8,
            "decay_steps": 1000,
            "stop_steps": 1000000,
        },
        {
            "type": "cosine",
            "start_lr": 1e-3,
            "stop_lr": 1e-8,
            "decay_steps": 1000,
            "stop_steps": 1000000,
        },
    ),
)
class TestLearningRateConsistent(unittest.TestCase):
    def setUp(self) -> None:
        (lr_param,) = self.param
        self.lr = BaseLR(**lr_param)
        self.step = 500000
        # an int step dispatches to the NumPy path (see BaseLR.array_namespace)
        self.ref = self.lr.value(self.step)

    def compare_test_with_ref(self, xp: Any) -> None:
        test = self.lr.value(xp.asarray(self.step))
        np.testing.assert_allclose(self.ref, to_numpy_array(test), atol=1e-10)

    @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
    def test_pt_consistent_with_ref(self) -> None:
        with torch.device("cpu"):
            self.compare_test_with_ref(torch_xp)

    @unittest.skipUnless(
        INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed"
    )
    @unittest.skipUnless(
        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
    )
    def test_array_api_strict(self) -> None:
        self.compare_test_with_ref(xp)

    @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
    def test_jax_consistent_with_ref(self) -> None:
        self.compare_test_with_ref(jnp)
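
Outside the test harness, the same consistency check can be reproduced by hand; a sketch assuming deepmd (and optionally JAX) is installed, using the "exp" parameters above:

import numpy as np
from deepmd.dpmodel.utils.learning_rate import BaseLR

lr = BaseLR(
    type="exp", start_lr=1e-3, stop_lr=1e-8, decay_steps=1000, stop_steps=1000000
)
ref = lr.value(500000)  # int step, evaluated on the NumPy path

# from deepmd.jax.env import jnp
# test = lr.value(jnp.asarray(500000))  # same schedule on a JAX array step
# np.testing.assert_allclose(ref, np.asarray(test), atol=1e-10)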