
Commit 0a807cf

chore(pd): sync get_lr from pt to pd (#5144)
Summary by CodeRabbit:

* **Refactor**: Updated learning rate scheduling implementation with enhanced type annotations and more flexible parameter handling for improved code clarity.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 854dca8 commit 0a807cf

File tree

1 file changed: +6 -9 lines changed


deepmd/pd/train/training.py

Lines changed: 6 additions & 9 deletions
```diff
@@ -30,6 +30,9 @@
 from deepmd.common import (
     symlink_prefix_files,
 )
+from deepmd.dpmodel.utils.learning_rate import (
+    BaseLR,
+)
 from deepmd.loggers.training import (
     format_training_message,
     format_training_message_per_task,
@@ -62,9 +65,6 @@
     SAMPLER_RECORD,
     enable_prim,
 )
-from deepmd.pd.utils.learning_rate import (
-    LearningRateExp,
-)
 from deepmd.pd.utils.stat import (
     make_stat_input,
 )
@@ -238,13 +238,10 @@ def get_sample():
                 _stat_file_path.root.close()
             return get_sample

-        def get_lr(lr_params):
-            assert lr_params.get("type", "exp") == "exp", (
-                "Only learning rate `exp` is supported!"
-            )
+        def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             lr_params["stop_steps"] = self.num_steps - self.warmup_steps
-            lr_exp = LearningRateExp(**lr_params)
-            return lr_exp
+            lr_schedule = BaseLR(**lr_params)
+            return lr_schedule

         # Optimizer
         if self.multi_task and training_params.get("optim_dict", None) is not None:
```
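For context on what the sync changes: `get_lr` no longer hard-codes the `exp` schedule. Instead of asserting `type == "exp"` and constructing `LearningRateExp` directly, it forwards the whole parameter dict to `BaseLR`, which, judging from the removed assert, is expected to pick the concrete schedule from the `type` key. Below is a minimal, self-contained sketch of that pattern; the class names, the registry mechanism, and the exact decay formula are illustrative assumptions, not deepmd-kit's actual implementation.

```python
import math


class BaseLRSketch:
    """Hypothetical stand-in for BaseLR: a base class that dispatches
    construction on the "type" key of the config dict (an assumption,
    not deepmd-kit's actual mechanism)."""

    _registry: dict[str, type] = {}

    def __init_subclass__(cls, *, lr_type: str, **kwargs):
        super().__init_subclass__(**kwargs)
        BaseLRSketch._registry[lr_type] = cls

    def __new__(cls, *args, type: str = "exp", **kwargs):
        # BaseLRSketch(**lr_params) resolves to the subclass registered
        # under lr_params["type"], so get_lr needs no hard-coded assert.
        if cls is BaseLRSketch:
            cls = BaseLRSketch._registry[type]
        return super().__new__(cls)

    def value(self, step: int) -> float:
        raise NotImplementedError


class ExpLRSketch(BaseLRSketch, lr_type="exp"):
    """Stepwise exponential decay, the scheme the old LearningRateExp
    implemented: lr = start_lr * decay_rate ** (step // decay_steps),
    clamped below at stop_lr."""

    def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs):
        self.start_lr = start_lr
        self.min_lr = stop_lr
        self.decay_steps = decay_steps
        # Choose decay_rate so the schedule reaches stop_lr near stop_steps.
        self.decay_rate = math.exp(
            math.log(stop_lr / start_lr) / (stop_steps / decay_steps)
        )

    def value(self, step: int) -> float:
        lr = self.start_lr * self.decay_rate ** (step // self.decay_steps)
        return max(lr, self.min_lr)


# Usage mirroring the new get_lr: pass the config through, let the base dispatch.
lr_params = {"type": "exp", "start_lr": 1e-3, "stop_lr": 1e-8,
             "decay_steps": 5000, "stop_steps": 1_000_000}
schedule = BaseLRSketch(**lr_params)
print(f"lr at step 100000: {schedule.value(100_000):.3e}")
```

Under this reading, the design win is that new schedule types only need to register themselves against `BaseLR`, while `get_lr` in both the pt and pd backends stays a thin wrapper that injects `stop_steps`.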

Comments (0)