
Commit 84f3953
refactor: unify learning rate schedulers with array API
- Refactor BaseLR in dpmodel to use array_api_compat for a backend-agnostic implementation
- Consolidate learning rate logic from the TF/PT/PD backends into the unified dpmodel layer
- Use array API operations (xp.where, xp.clip, etc.) for JIT compatibility across backends
- Add warmup support (warmup_steps, warmup_ratio, warmup_start_factor) during the refactoring
- Add a stop_ratio parameter as an alternative to stop_lr for flexible configuration
- Implement mutual-exclusion validation for stop_lr/stop_ratio and warmup_steps/warmup_ratio
- Update all backends to use the unified BaseLR implementation
- Add comprehensive consistency tests across the NumPy/PyTorch/JAX/array_api_strict backends
1 parent 2a9667e commit 84f3953

File tree

21 files changed: +1074 −319 lines

deepmd/dpmodel/utils/learning_rate.py

Lines changed: 286 additions & 52 deletions
Large diffs are not rendered by default.
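Since the diff for this file is not rendered, here is a rough, hypothetical sketch of the design the commit message describes: one exponential-decay schedule with linear warmup, written entirely in array API calls so any backend's arrays work. The class name, parameter defaults, and the exact decay formula below are assumptions for illustration, not the real implementation.

```python
# Hypothetical sketch only; the real BaseLR in
# deepmd/dpmodel/utils/learning_rate.py may differ in names and formula.
import array_api_compat
import numpy as np


class UnifiedExpLR:  # stand-in name, not the actual class body
    def __init__(
        self,
        start_lr: float,
        num_steps: int,
        decay_steps: int = 5000,
        stop_lr: float | None = None,
        stop_ratio: float | None = None,
        warmup_steps: int | None = None,
        warmup_ratio: float | None = None,
        warmup_start_factor: float = 0.0,
    ) -> None:
        # Mutual exclusion, as described in the commit message.
        if stop_lr is not None and stop_ratio is not None:
            raise ValueError("stop_lr and stop_ratio are mutually exclusive")
        if warmup_steps is not None and warmup_ratio is not None:
            raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
        self.start_lr = start_lr
        # stop_ratio fixes the final LR relative to start_lr instead of absolutely.
        self.stop_lr = (
            start_lr * stop_ratio if stop_ratio is not None else (stop_lr or 1e-8)
        )
        self.warmup_steps = (
            warmup_steps
            if warmup_steps is not None
            else int((warmup_ratio or 0.0) * num_steps)
        )
        self.warmup_start_factor = warmup_start_factor
        self.decay_steps = decay_steps
        # Choose the decay rate so the LR reaches stop_lr over the post-warmup span.
        decay_span = max(num_steps - self.warmup_steps, 1)
        self.decay_rate = float(
            np.exp(np.log(self.stop_lr / self.start_lr) / (decay_span / decay_steps))
        )

    def value(self, step):
        """LR at `step` for NumPy/PyTorch/JAX/array_api_strict array inputs."""
        xp = array_api_compat.array_namespace(step)
        warmup = self.start_lr * (
            self.warmup_start_factor
            + (1.0 - self.warmup_start_factor) * step / max(self.warmup_steps, 1)
        )
        decayed = self.start_lr * self.decay_rate ** (
            (step - self.warmup_steps) / self.decay_steps
        )
        decayed = xp.clip(decayed, min=self.stop_lr, max=self.start_lr)
        # Branchless select keeps the function traceable under JIT on every backend.
        return xp.where(step < self.warmup_steps, warmup, decayed)
```

Because `value()` only touches `xp.*` operations and scalar arithmetic, the same object can serve every backend's trainer and stay compatible with tracing compilers, which is presumably what enables the cross-backend consistency tests the commit adds.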

deepmd/pd/train/training.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -239,7 +239,7 @@ def get_sample():
         return get_sample
 
     def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-        lr_params["stop_steps"] = self.num_steps - self.warmup_steps
+        lr_params["num_steps"] = self.num_steps
         lr_schedule = BaseLR(**lr_params)
         return lr_schedule
 
@@ -387,11 +387,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
         )
 
         # Learning rate
-        self.warmup_steps = training_params.get("warmup_steps", 0)
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, (
-            "Warm up steps must be less than total training steps!"
-        )
         if self.multi_task and config.get("learning_rate_dict", None) is not None:
             self.lr_exp = {}
             for model_key in self.model_keys:
@@ -580,18 +576,13 @@ def single_model_finetune(
 
         # TODO add lr warmups for multitask
         # author: iProzd
-        def warm_up_linear(step, warmup_steps):
-            if step < warmup_steps:
-                return step / warmup_steps
-            else:
-                return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr
-
         # TODO add optimizers for multitask
         # author: iProzd
         if self.opt_type == "Adam":
             self.scheduler = paddle.optimizer.lr.LambdaDecay(
                 learning_rate=self.lr_exp.start_lr,
-                lr_lambda=lambda step: warm_up_linear(step, self.warmup_steps),
+                lr_lambda=lambda step: self.lr_exp.value(step + self.start_step)
+                / self.lr_exp.start_lr,
             )
             self.optimizer = paddle.optimizer.Adam(
                 learning_rate=self.scheduler, parameters=self.wrapper.parameters()
@@ -755,10 +746,7 @@ def step(_step_id, task_key="Default") -> None:
                     fout1.flush()
             if self.opt_type == "Adam":
                 cur_lr = self.scheduler.get_lr()
-                if _step_id < self.warmup_steps:
-                    pref_lr = _lr.start_lr
-                else:
-                    pref_lr = cur_lr
+                pref_lr = cur_lr
 
                 # disable synchronization in forward-backward manually
                 # as derivatives exist in model forward
```

deepmd/pd/utils/utils.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -239,16 +239,17 @@ def to_numpy_array(
 ):
     if xx is None:
         return None
-    assert xx is not None
+    if isinstance(xx, (float, int)):
+        return np.array(xx)
+    if isinstance(xx, np.ndarray):
+        return xx
     # Create a reverse mapping of PD_PRECISION_DICT
     reverse_precision_dict = {v: k for k, v in PD_PRECISION_DICT.items()}
     # Use the reverse mapping to find keys with the desired value
     prec = reverse_precision_dict.get(xx.dtype, None)
     prec = NP_PRECISION_DICT.get(prec, np.float64)
     if prec is None:
         raise ValueError(f"unknown precision {xx.dtype}")
-    if isinstance(xx, np.ndarray):
-        return xx.astype(prec)
     if xx.dtype == paddle.bfloat16:
         xx = xx.astype(paddle.get_default_dtype())
     return xx.numpy().astype(prec)
```

deepmd/pt/train/training.py

Lines changed: 6 additions & 35 deletions
```diff
@@ -273,7 +273,7 @@ def get_sample() -> Any:
         return get_sample
 
     def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-        lr_params["stop_steps"] = self.num_steps - self.warmup_steps
+        lr_params["num_steps"] = self.num_steps
         lr_schedule = BaseLR(**lr_params)
         return lr_schedule
 
@@ -431,27 +431,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
         )
 
         # Learning rate
-        warmup_steps = training_params.get("warmup_steps", None)
-        warmup_ratio = training_params.get("warmup_ratio", None)
-        if warmup_steps is not None:
-            self.warmup_steps = warmup_steps
-        elif warmup_ratio is not None:
-            if not 0 <= warmup_ratio < 1:
-                raise ValueError(f"warmup_ratio must be in [0, 1), got {warmup_ratio}")
-            self.warmup_steps = int(warmup_ratio * self.num_steps)
-            if self.warmup_steps == 0 and warmup_ratio > 0:
-                log.warning(
-                    f"warmup_ratio {warmup_ratio} results in 0 warmup steps "
-                    f"due to truncation. Consider using a larger ratio or "
-                    f"specify warmup_steps directly."
-                )
-        else:
-            self.warmup_steps = 0
-        self.warmup_start_factor = training_params.get("warmup_start_factor", 0.0)
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, (
-            "Warm up steps must be less than total training steps!"
-        )
         if self.multi_task and config.get("learning_rate_dict", None) is not None:
             self.lr_exp = {}
             for model_key in self.model_keys:
@@ -697,14 +677,6 @@ def single_model_finetune(
 
         # TODO add lr warmups for multitask
         # author: iProzd
-        def warm_up_linear(step: int, warmup_steps: int) -> float:
-            if step < warmup_steps:
-                return self.warmup_start_factor + (1.0 - self.warmup_start_factor) * (
-                    step / warmup_steps
-                )
-            else:
-                return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr
-
         # TODO add optimizers for multitask
         # author: iProzd
         if self.opt_type in ["Adam", "AdamW"]:
@@ -725,7 +697,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 self.optimizer.load_state_dict(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+                lambda step: self.lr_exp.value(step + self.start_step)
+                / self.lr_exp.start_lr,
             )
         elif self.opt_type == "LKF":
             self.optimizer = LKFOptimizer(
@@ -748,7 +721,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 self.optimizer.load_state_dict(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+                lambda step: self.lr_exp.value(step + self.start_step)
+                / self.lr_exp.start_lr,
             )
         else:
             raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
@@ -822,10 +796,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                     fout1.flush()
             if self.opt_type in ["Adam", "AdamW", "AdaMuon"]:
                 cur_lr = self.scheduler.get_last_lr()[0]
-                if _step_id < self.warmup_steps:
-                    pref_lr = _lr.start_lr
-                else:
-                    pref_lr = cur_lr
+                pref_lr = cur_lr
                 model_pred, loss, more_loss = self.wrapper(
                     **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
                 )
```

deepmd/pt/utils/utils.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -227,11 +227,15 @@ def to_numpy_array(xx: None) -> None: ...
 
 
 def to_numpy_array(
-    xx: torch.Tensor | None,
+    xx: torch.Tensor | np.ndarray | float | None,
 ) -> np.ndarray | None:
     if xx is None:
         return None
-    assert xx is not None
+    if isinstance(xx, (float, int)):
+        return np.array(xx)
+    if isinstance(xx, np.ndarray):
+        return xx
+    assert isinstance(xx, torch.Tensor)
     # Create a reverse mapping of PT_PRECISION_DICT
     reverse_precision_dict = {v: k for k, v in PT_PRECISION_DICT.items()}
     # Use the reverse mapping to find keys with the desired value
```
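The widened signature makes `to_numpy_array` a safe catch-all for values that may already be NumPy arrays or plain Python scalars (e.g. outputs of the unified `BaseLR`) instead of tripping the old assert. A quick sketch of the new contract, assuming the import path matches the file above:

```python
import numpy as np
import torch

from deepmd.pt.utils.utils import to_numpy_array

assert to_numpy_array(None) is None          # None passes through
assert to_numpy_array(3.5) == np.array(3.5)  # Python scalars are wrapped
a = np.ones(3)
assert to_numpy_array(a) is a                # NumPy arrays are returned as-is
out = to_numpy_array(torch.ones(3))          # tensors go through the precision map
assert isinstance(out, np.ndarray)
```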

deepmd/tf/fit/dipole.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -388,7 +388,7 @@ def get_loss(self, loss: dict, lr) -> Loss:
         ----------
         loss : dict
             the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
             the learning rate
 
         Returns
```

deepmd/tf/fit/dos.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -655,7 +655,7 @@ def get_loss(self, loss: dict, lr) -> Loss:
         ----------
         loss : dict
             the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
             the learning rate
 
         Returns
```

deepmd/tf/fit/ener.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -856,7 +856,7 @@ def get_loss(self, loss: dict, lr) -> Loss:
         ----------
         loss : dict
             The loss function parameters.
-        lr : LearningRateExp
+        lr : LearningRateSchedule
             The learning rate.
 
         Returns
```

deepmd/tf/fit/fitting.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -73,7 +73,7 @@ def get_loss(self, loss: dict, lr) -> Loss:
         ----------
         loss : dict
             the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
             the learning rate
 
         Returns
```

deepmd/tf/fit/polar.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -863,7 +863,7 @@ def get_loss(self, loss: dict, lr) -> Loss:
         ----------
         loss : dict
             the loss dict
-        lr : LearningRateExp
+        lr : LearningRateSchedule
             the learning rate
 
         Returns
```
