deepmd/pt/train/training.py: 22 changes (20 additions, 2 deletions)
@@ -416,7 +416,23 @@ def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
         )

         # Learning rate
-        self.warmup_steps = training_params.get("warmup_steps", 0)
+        warmup_steps = training_params.get("warmup_steps", None)
+        warmup_ratio = training_params.get("warmup_ratio", None)
+        if warmup_steps is not None:
+            self.warmup_steps = warmup_steps
+        elif warmup_ratio is not None:
+            if not 0 <= warmup_ratio < 1:
+                raise ValueError(f"warmup_ratio must be in [0, 1), got {warmup_ratio}")
+            self.warmup_steps = int(warmup_ratio * self.num_steps)
+            if self.warmup_steps == 0 and warmup_ratio > 0:
+                log.warning(
+                    f"warmup_ratio {warmup_ratio} results in 0 warmup steps "
+                    f"due to truncation. Consider using a larger ratio or "
+                    f"specifying warmup_steps directly."
+                )
+        else:
+            self.warmup_steps = 0
+        self.warmup_start_factor = training_params.get("warmup_start_factor", 0.0)
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
         assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, (
             "Warm up steps must be less than total training steps!"
@@ -668,7 +684,9 @@ def single_model_finetune(
         # author: iProzd
         def warm_up_linear(step: int, warmup_steps: int) -> float:
             if step < warmup_steps:
-                return step / warmup_steps
+                return self.warmup_start_factor + (1.0 - self.warmup_start_factor) * (
+                    step / warmup_steps
+                )
             else:
                 return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr

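To make the new scheduling behaviour concrete, here is a minimal, self-contained sketch of the logic added above. The constants (`num_steps`, `start_lr`, the ratio values) are purely illustrative and not taken from any real DeePMD-kit configuration; only the formulas mirror the diff.

```python
# Standalone sketch of the warmup schedule added in training.py above.
# All constants below are hypothetical; only the formulas mirror the diff.

num_steps = 100_000
start_lr = 1.0e-3

warmup_ratio = 0.05                        # 5% of training -> 5000 warmup steps
warmup_steps = int(warmup_ratio * num_steps)
warmup_start_factor = 0.1                  # warmup begins at 10% of start_lr


def warmup_scale(step: int) -> float:
    """LR multiplier during warmup, mirroring warm_up_linear."""
    if step < warmup_steps:
        return warmup_start_factor + (1.0 - warmup_start_factor) * (
            step / warmup_steps
        )
    return 1.0  # after warmup, the exponential decay schedule takes over


for step in (0, 2500, 5000):
    print(step, warmup_scale(step) * start_lr)
# 0    -> 1.0e-4  (warmup_start_factor * start_lr)
# 2500 -> 5.5e-4  (halfway through warmup)
# 5000 -> 1.0e-3  (warmup finished; decay starts from start_lr)
```

With the previous behaviour, which corresponds to `warmup_start_factor` fixed at 0.0, the very first step would use a learning rate of exactly zero.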
deepmd/utils/argcheck.py: 24 changes (24 additions, 0 deletions)
@@ -3223,6 +3223,17 @@ def training_args(
         "the learning rate begins at zero and progressively increases linearly to `start_lr`, "
         "rather than starting directly from `start_lr`"
     )
+    doc_warmup_ratio = (
+        "The ratio of warmup steps to total training steps. "
+        "The actual number of warmup steps is calculated as `warmup_ratio * numb_steps` (truncated to an integer). "
+        "Valid values are in the range [0.0, 1.0). "
+        "If `warmup_steps` is set, this option will be ignored."
+    )
+    doc_warmup_start_factor = (
+        "The ratio of the initial warmup learning rate to the target learning rate `start_lr`. "
+        "The warmup learning rate will increase linearly from `warmup_start_factor * start_lr` to `start_lr`. "
+        "Default is 0.0, meaning the learning rate starts from zero."
+    )
     doc_gradient_max_norm = (
         "Clips the gradient norm to a maximum value. "
         "If the gradient norm exceeds this value, it will be clipped to this limit. "
@@ -3336,6 +3347,19 @@
             optional=True,
             doc=doc_only_pt_supported + doc_warmup_steps,
         ),
+        Argument(
+            "warmup_ratio",
+            float,
+            optional=True,
+            doc=doc_only_pt_supported + doc_warmup_ratio,
+        ),
+        Argument(
+            "warmup_start_factor",
+            float,
+            optional=True,
+            default=0.0,
+            doc=doc_only_pt_supported + doc_warmup_start_factor,
+        ),
         Argument(
             "gradient_max_norm",
             float,
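The precedence between the two ways of specifying warmup, as implemented in training.py above, can be restated with a small hypothetical helper. `resolve_warmup_steps` is not part of DeePMD-kit; it only illustrates which value wins when both options appear in the training parameters.

```python
# Hypothetical helper restating the precedence rules from training.py:
# an explicit warmup_steps wins, otherwise warmup_ratio is converted to a
# step count, otherwise there is no warmup.


def resolve_warmup_steps(training_params: dict, num_steps: int) -> int:
    warmup_steps = training_params.get("warmup_steps")
    warmup_ratio = training_params.get("warmup_ratio")
    if warmup_steps is not None:
        return warmup_steps
    if warmup_ratio is not None:
        if not 0 <= warmup_ratio < 1:
            raise ValueError(f"warmup_ratio must be in [0, 1), got {warmup_ratio}")
        return int(warmup_ratio * num_steps)
    return 0


print(resolve_warmup_steps({"warmup_ratio": 0.02}, 1_000_000))   # 20000
print(resolve_warmup_steps({"warmup_steps": 500,
                            "warmup_ratio": 0.5}, 1_000_000))    # 500: warmup_steps wins
print(resolve_warmup_steps({}, 1_000_000))                       # 0: no warmup
```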