diff --git a/deepmd/dpmodel/utils/learning_rate.py b/deepmd/dpmodel/utils/learning_rate.py index f82a42660b..7ea50583e2 100644 --- a/deepmd/dpmodel/utils/learning_rate.py +++ b/deepmd/dpmodel/utils/learning_rate.py @@ -7,11 +7,15 @@ Any, ) +import array_api_compat import numpy as np from deepmd.common import ( j_get_type, ) +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.utils.plugin import ( PluginVariant, make_plugin_registry, @@ -44,8 +48,9 @@ def __init__( self.stop_steps = stop_steps @abstractmethod - def value(self, step: int) -> np.float64: + def value(self, step: int | Array) -> Array: """Get the learning rate at the given step.""" + # in optax, step will be a jnp.ndarray passed in JIT mode pass @@ -88,16 +93,23 @@ def __init__( self.decay_steps = default_ds self.decay_rate = np.exp( np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps) - ) + ).item() if decay_rate is not None: self.decay_rate = decay_rate self.min_lr = self.stop_lr - def value(self, step: int) -> np.float64: + def value(self, step: int | Array) -> Array: """Get the learning rate at the given step.""" - step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps) - if step_lr < self.min_lr: - step_lr = self.min_lr + if not array_api_compat.is_array_api_obj(step): + step = np.asarray(step) + xp = array_api_compat.array_namespace(step) + step_lr = self.start_lr * xp.pow( + xp.asarray(self.decay_rate, device=array_api_compat.device(step)), + xp.astype(step // self.decay_steps, xp.float64), + ) + # the original implementation `if step_lr < self.min_lr:` + # will cause a dynamic graph which is unsupported in JAX JIT + step_lr = xp.clip(step_lr, self.min_lr, None) return step_lr @@ -128,12 +140,24 @@ def __init__( super().__init__(start_lr, stop_lr, stop_steps, **kwargs) self.lr_min_factor = stop_lr / start_lr - def value(self, step: int) -> np.float64: - if step >= self.stop_steps: - return self.start_lr * self.lr_min_factor - return self.start_lr * ( + def value(self, step: int | Array) -> Array: + if not array_api_compat.is_array_api_obj(step): + step = np.asarray(step) + xp = array_api_compat.array_namespace(step) + min_lr = self.start_lr * self.lr_min_factor + step_lr = self.start_lr * ( self.lr_min_factor + 0.5 * (1 - self.lr_min_factor) - * (1 + np.cos(np.pi * (step / self.stop_steps))) + * ( + 1 + + xp.cos( + xp.asarray( + xp.pi * (xp.astype(step, xp.float64) / self.stop_steps), + device=array_api_compat.device(step), + ) + ) + ) ) + step_lr = xp.where(step >= self.stop_steps, min_lr, step_lr) + return step_lr diff --git a/pyproject.toml b/pyproject.toml index 5b089c4558..0adba2729a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,8 @@ test = [ "pytest-sugar", "pytest-split", "dpgui", - 'array-api-strict>=2,!=2.1.1;python_version>="3.9"', + # to support Array API 2024.12 + 'array-api-strict>=2.2;python_version>="3.9"', ] docs = [ "sphinx>=3.1.1", diff --git a/source/tests/array_api_strict/__init__.py b/source/tests/array_api_strict/__init__.py index 27f15682e0..d43f0d7353 100644 --- a/source/tests/array_api_strict/__init__.py +++ b/source/tests/array_api_strict/__init__.py @@ -5,4 +5,4 @@ # this is the default version in the latest array_api_strict, # but in old versions it may be 2022.12 -array_api_strict.set_array_api_strict_flags(api_version="2023.12") +array_api_strict.set_array_api_strict_flags(api_version="2024.12") diff --git a/source/tests/consistent/test_learning_rate.py b/source/tests/consistent/test_learning_rate.py new file mode 100644 index 0000000000..5767f3165e --- /dev/null +++ b/source/tests/consistent/test_learning_rate.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import sys +import unittest + +import numpy as np + +from deepmd.dpmodel.array_api import ( + Array, +) +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.dpmodel.utils.learning_rate import ( + BaseLR, +) + +from .common import ( + INSTALLED_ARRAY_API_STRICT, + INSTALLED_JAX, + INSTALLED_PT, + parameterized, +) + +if INSTALLED_PT: + from deepmd.pt.utils.utils import ( + to_torch_tensor, + ) + +if INSTALLED_JAX: + from deepmd.jax.env import ( + jnp, + ) +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict as xp + + +@parameterized( + ( + { + "type": "exp", + "start_lr": 1e-3, + "stop_lr": 1e-8, + "decay_steps": 1000, + "stop_steps": 1000000, + }, + { + "type": "cosine", + "start_lr": 1e-3, + "stop_lr": 1e-8, + "decay_steps": 1000, + "stop_steps": 1000000, + }, + ), +) +class TestLearningRateConsistent(unittest.TestCase): + def setUp(self) -> None: + (lr_param,) = self.param + self.lr = BaseLR(**lr_param) + self.step = 500000 + self.ref = self.lr.value(self.step) + + def compare_test_with_ref(self, step: Array) -> None: + test = self.lr.value(step) + np.testing.assert_allclose(self.ref, to_numpy_array(test), atol=1e-10) + + def compare_numpy_with_ref(self, step: Array) -> None: + self.compare_test_with_ref(np.asarray(step)) + + @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed") + def test_pt_consistent_with_ref(self) -> None: + self.compare_test_with_ref(to_torch_tensor(self.step)) + + @unittest.skipUnless( + INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed" + ) + @unittest.skipUnless( + sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8" + ) + def test_array_api_strict(self) -> None: + self.compare_test_with_ref(xp.asarray(self.step)) + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_jax_consistent_with_ref(self) -> None: + self.compare_test_with_ref(jnp.array(self.step))