diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py
index 4e5fea081f..dd0fbdc94b 100644
--- a/deepmd/pd/train/training.py
+++ b/deepmd/pd/train/training.py
@@ -30,6 +30,9 @@
 from deepmd.common import (
     symlink_prefix_files,
 )
+from deepmd.dpmodel.utils.learning_rate import (
+    BaseLR,
+)
 from deepmd.loggers.training import (
     format_training_message,
     format_training_message_per_task,
@@ -62,9 +65,6 @@
     SAMPLER_RECORD,
     enable_prim,
 )
-from deepmd.pd.utils.learning_rate import (
-    LearningRateExp,
-)
 from deepmd.pd.utils.stat import (
     make_stat_input,
 )
@@ -238,13 +238,10 @@ def get_sample():
                 _stat_file_path.root.close()
             return get_sample
 
-        def get_lr(lr_params):
-            assert lr_params.get("type", "exp") == "exp", (
-                "Only learning rate `exp` is supported!"
-            )
+        def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             lr_params["stop_steps"] = self.num_steps - self.warmup_steps
-            lr_exp = LearningRateExp(**lr_params)
-            return lr_exp
+            lr_schedule = BaseLR(**lr_params)
+            return lr_schedule
 
         # Optimizer
         if self.multi_task and training_params.get("optim_dict", None) is not None:
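
Note: the hard-coded `exp` assertion in `get_lr` is dropped because `BaseLR` now owns scheduler selection. Below is a minimal self-contained sketch of how such a dispatching base class could work, assuming a registry keyed on the `type` parameter with `"exp"` as the default; the registry mechanics, class layout, and `value()` signature here are illustrative assumptions, not deepmd's actual implementation.

```python
import math
from typing import Any


class BaseLR:
    """Hypothetical dispatching base class in the style of BaseLR."""

    _registry: dict[str, type["BaseLR"]] = {}

    def __new__(cls, *args: Any, **kwargs: Any) -> "BaseLR":
        # When constructed via the base class, dispatch on the "type" key
        # (defaulting to "exp"), mirroring the old get_lr default.
        if cls is BaseLR:
            lr_type = kwargs.get("type", "exp")
            try:
                cls = BaseLR._registry[lr_type]
            except KeyError:
                raise ValueError(f"Unknown learning rate type: {lr_type!r}") from None
        return super().__new__(cls)

    @classmethod
    def register(cls, name: str):
        """Class decorator that records a subclass under a type name."""

        def decorator(subcls: type["BaseLR"]) -> type["BaseLR"]:
            cls._registry[name] = subcls
            return subcls

        return decorator

    def value(self, step: int) -> float:
        raise NotImplementedError


@BaseLR.register("exp")
class LearningRateExp(BaseLR):
    """Exponential decay: lr(step) = start_lr * decay_rate ** (step // decay_steps)."""

    def __init__(
        self,
        start_lr: float,
        stop_lr: float,
        decay_steps: int,
        stop_steps: int,
        **kwargs: Any,  # swallows the dispatch key "type"
    ) -> None:
        self.start_lr = start_lr
        self.decay_steps = decay_steps
        # Pick the decay rate so the schedule reaches stop_lr at stop_steps.
        self.decay_rate = math.exp(
            math.log(stop_lr / start_lr) / (stop_steps / decay_steps)
        )

    def value(self, step: int) -> float:
        return self.start_lr * self.decay_rate ** (step // self.decay_steps)


# Usage matching the new get_lr: no assertion on "type" is needed,
# because unknown types fail inside BaseLR's dispatch.
lr = BaseLR(
    type="exp", start_lr=1e-3, stop_lr=1e-8, decay_steps=5000, stop_steps=1_000_000
)
print(lr.value(0), lr.value(500_000))
```

With this shape, adding a new schedule means registering another subclass; `training.py` no longer needs to know which learning rate types exist.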