46 changes: 35 additions & 11 deletions deepmd/dpmodel/utils/learning_rate.py
@@ -7,11 +7,15 @@
     Any,
 )
 
+import array_api_compat
 import numpy as np
 
 from deepmd.common import (
     j_get_type,
 )
+from deepmd.dpmodel.array_api import (
+    Array,
+)
 from deepmd.utils.plugin import (
     PluginVariant,
     make_plugin_registry,
@@ -44,8 +48,9 @@ def __init__(
         self.stop_steps = stop_steps
 
     @abstractmethod
-    def value(self, step: int) -> np.float64:
+    def value(self, step: int | Array) -> Array:
        """Get the learning rate at the given step."""
+        # in optax, step will be a jnp.ndarray passed in JIT mode
         pass
 
 
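The widened `int | Array` signature and the optax note above reflect how JAX traces scalar arguments: inside a jitted function, `step` is no longer a Python int. A minimal standalone sketch of that behavior (assuming JAX is installed; the schedule constants are illustrative, not from this diff):

import jax
import jax.numpy as jnp

def lr_value(step):
    # under jax.jit, `step` arrives as a traced jnp.ndarray, not an int
    return 1e-3 * 0.95 ** (step // 1000)

jitted_lr = jax.jit(lr_value)
print(jitted_lr(5000))  # step is traced as a jnp.ndarray here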
@@ -88,16 +93,23 @@ def __init__(
             self.decay_steps = default_ds
         self.decay_rate = np.exp(
             np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps)
-        )
+        ).item()
         if decay_rate is not None:
             self.decay_rate = decay_rate
         self.min_lr = self.stop_lr
 
-    def value(self, step: int) -> np.float64:
+    def value(self, step: int | Array) -> Array:
         """Get the learning rate at the given step."""
-        step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps)
-        if step_lr < self.min_lr:
-            step_lr = self.min_lr
+        if not array_api_compat.is_array_api_obj(step):
+            step = np.asarray(step)
+        xp = array_api_compat.array_namespace(step)
+        step_lr = self.start_lr * xp.pow(
+            xp.asarray(self.decay_rate, device=array_api_compat.device(step)),
+            xp.astype(step // self.decay_steps, xp.float64),
+        )
+        # the original implementation `if step_lr < self.min_lr:`
+        # will cause a dynamic graph which is unsupported in JAX JIT
+        step_lr = xp.clip(step_lr, self.min_lr, None)
         return step_lr
 
 
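The dynamic-graph comment in the hunk above is worth unpacking: under `jax.jit`, a Python `if` on a traced value forces concretization and raises an error, so the clamp must be expressed as an array operation (`xp.clip` here, `xp.where` in the cosine schedule below). A minimal sketch of the failure mode, assuming JAX is installed:

import jax
import jax.numpy as jnp

@jax.jit
def clamped_bad(step_lr, min_lr):
    if step_lr < min_lr:  # Python branch on a traced value: fails under jit
        step_lr = min_lr
    return step_lr

@jax.jit
def clamped_good(step_lr, min_lr):
    # branch-free equivalent: safe to trace
    return jnp.clip(step_lr, min_lr, None)

print(clamped_good(jnp.asarray(1e-9), jnp.asarray(1e-8)))  # 1e-08
# clamped_bad(jnp.asarray(1e-9), jnp.asarray(1e-8))  # raises a tracer/concretization error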
@@ -128,12 +140,24 @@ def __init__(
         super().__init__(start_lr, stop_lr, stop_steps, **kwargs)
         self.lr_min_factor = stop_lr / start_lr
 
-    def value(self, step: int) -> np.float64:
-        if step >= self.stop_steps:
-            return self.start_lr * self.lr_min_factor
-        return self.start_lr * (
+    def value(self, step: int | Array) -> Array:
+        if not array_api_compat.is_array_api_obj(step):
+            step = np.asarray(step)
+        xp = array_api_compat.array_namespace(step)
+        min_lr = self.start_lr * self.lr_min_factor
+        step_lr = self.start_lr * (
             self.lr_min_factor
             + 0.5
             * (1 - self.lr_min_factor)
-            * (1 + np.cos(np.pi * (step / self.stop_steps)))
+            * (
+                1
+                + xp.cos(
+                    xp.asarray(
+                        xp.pi * (xp.astype(step, xp.float64) / self.stop_steps),
+                        device=array_api_compat.device(step),
+                    )
+                )
+            )
         )
+        step_lr = xp.where(step >= self.stop_steps, min_lr, step_lr)
+        return step_lr
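As a sanity check on both schedules (using the parameters from the new consistency test below): halfway through training, the exponential schedule sits halfway between start_lr and stop_lr in log space, and the cosine term vanishes at cos(pi/2).

import numpy as np

start_lr, stop_lr, decay_steps, stop_steps = 1e-3, 1e-8, 1000, 1000000
step = 500000
# exponential: 1e-3 * 10**-2.5 ~= 3.16e-6
decay_rate = np.exp(np.log(stop_lr / start_lr) / (stop_steps / decay_steps)).item()
print(start_lr * decay_rate ** (step // decay_steps))
# cosine: cos(pi * 0.5) == 0, so lr ~= start_lr / 2 for tiny stop_lr/start_lr
f = stop_lr / start_lr
print(start_lr * (f + 0.5 * (1 - f) * (1 + np.cos(np.pi * step / stop_steps))))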
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -86,7 +86,8 @@ test = [
     "pytest-sugar",
     "pytest-split",
     "dpgui",
-    'array-api-strict>=2,!=2.1.1;python_version>="3.9"',
+    # to support Array API 2024.12
+    'array-api-strict>=2.2;python_version>="3.9"',
 ]
 docs = [
     "sphinx>=3.1.1",
2 changes: 1 addition & 1 deletion source/tests/array_api_strict/__init__.py
@@ -5,4 +5,4 @@
 
 # this is the default version in the latest array_api_strict,
 # but in old versions it may be 2022.12
-array_api_strict.set_array_api_strict_flags(api_version="2023.12")
+array_api_strict.set_array_api_strict_flags(api_version="2024.12")
84 changes: 84 additions & 0 deletions source/tests/consistent/test_learning_rate.py
@@ -0,0 +1,84 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import sys
import unittest

import numpy as np

from deepmd.dpmodel.array_api import (
    Array,
)
from deepmd.dpmodel.common import (
    to_numpy_array,
)
from deepmd.dpmodel.utils.learning_rate import (
    BaseLR,
)

from .common import (
    INSTALLED_ARRAY_API_STRICT,
    INSTALLED_JAX,
    INSTALLED_PT,
    parameterized,
)

if INSTALLED_PT:
    from deepmd.pt.utils.utils import (
        to_torch_tensor,
    )

if INSTALLED_JAX:
    from deepmd.jax.env import (
        jnp,
    )
if INSTALLED_ARRAY_API_STRICT:
    import array_api_strict as xp


@parameterized(
    (
        {
            "type": "exp",
            "start_lr": 1e-3,
            "stop_lr": 1e-8,
            "decay_steps": 1000,
            "stop_steps": 1000000,
        },
        {
            "type": "cosine",
            "start_lr": 1e-3,
            "stop_lr": 1e-8,
            "decay_steps": 1000,
            "stop_steps": 1000000,
        },
    ),
)
class TestLearningRateConsistent(unittest.TestCase):
    def setUp(self) -> None:
        (lr_param,) = self.param
        self.lr = BaseLR(**lr_param)
        self.step = 500000
        self.ref = self.lr.value(self.step)

    def compare_test_with_ref(self, step: Array) -> None:
        test = self.lr.value(step)
        np.testing.assert_allclose(self.ref, to_numpy_array(test), atol=1e-10)

    def compare_numpy_with_ref(self, step: Array) -> None:
        self.compare_test_with_ref(np.asarray(step))

    @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
    def test_pt_consistent_with_ref(self) -> None:
        self.compare_test_with_ref(to_torch_tensor(self.step))

    @unittest.skipUnless(
        INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed"
    )
    @unittest.skipUnless(
        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
    )
    def test_array_api_strict(self) -> None:
        self.compare_test_with_ref(xp.asarray(self.step))

    @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
    def test_jax_consistent_with_ref(self) -> None:
        self.compare_test_with_ref(jnp.array(self.step))
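The test builds one reference value from an integer step in setUp, then asserts that each installed backend (PyTorch tensor, array_api_strict array, JAX array) reproduces it through the same BaseLR.value call. To run just this suite locally, something like the following should work (the invocation is a guess at the repo's usual pytest layout):

pytest source/tests/consistent/test_learning_rate.py -v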