16 changes: 13 additions & 3 deletions deepmd/pd/train/training.py
@@ -115,6 +115,7 @@ def __init__(
self.restart_training = restart_model is not None
model_params = config["model"]
training_params = config["training"]
+ optimizer_params = config.get("optimizer") or {}
self.multi_task = "model_dict" in model_params
self.finetune_links = finetune_links
self.finetune_update_stat = False
@@ -149,13 +150,18 @@ def __init__(
self.lcurve_should_print_header = True

def get_opt_param(params):
opt_type = params.get("opt_type", "Adam")
opt_type = params.get("type", "Adam")
if opt_type != "Adam":
raise ValueError(f"Not supported optimizer type '{opt_type}'")
opt_param = {
"kf_blocksize": params.get("kf_blocksize", 5120),
"kf_start_pref_e": params.get("kf_start_pref_e", 1),
"kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
"kf_start_pref_f": params.get("kf_start_pref_f", 1),
"kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
"adam_beta1": params.get("adam_beta1", 0.9),
"adam_beta2": params.get("adam_beta2", 0.999),
"weight_decay": params.get("weight_decay", 0.0),
}
return opt_type, opt_param

@@ -259,7 +265,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
self.optim_dict[model_key]
)
else:
- self.opt_type, self.opt_param = get_opt_param(training_params)
+ self.opt_type, self.opt_param = get_opt_param(optimizer_params)

# loss_param_tmp for Hessian activation
loss_param_tmp = None
@@ -594,7 +600,11 @@ def warm_up_linear(step, warmup_steps):
lr_lambda=lambda step: warm_up_linear(step, self.warmup_steps),
)
self.optimizer = paddle.optimizer.Adam(
- learning_rate=self.scheduler, parameters=self.wrapper.parameters()
+ learning_rate=self.scheduler,
+ parameters=self.wrapper.parameters(),
+ beta1=float(self.opt_param["adam_beta1"]),
+ beta2=float(self.opt_param["adam_beta2"]),
+ weight_decay=float(self.opt_param["weight_decay"]),
)
if optimizer_state_dict is not None and self.restart_training:
self.optimizer.set_state_dict(optimizer_state_dict)
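To make the Paddle changes concrete, here is a minimal sketch (not part of the diff) of the new top-level "optimizer" section and of how get_opt_param resolves it. The config layout mirrors the keys used above; the helper is a simplified stand-in that drops the kf_* entries, and the concrete values are only illustrative.

# Sketch only: simplified mirror of get_opt_param from the Paddle trainer above.
config = {
    "model": {},      # abbreviated
    "training": {},   # abbreviated
    "optimizer": {
        "type": "Adam",
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "weight_decay": 0.0,
    },
}
optimizer_params = config.get("optimizer") or {}

def get_opt_param(params):
    opt_type = params.get("type", "Adam")
    if opt_type != "Adam":
        raise ValueError(f"Not supported optimizer type '{opt_type}'")
    opt_param = {
        "adam_beta1": params.get("adam_beta1", 0.9),
        "adam_beta2": params.get("adam_beta2", 0.999),
        "weight_decay": params.get("weight_decay", 0.0),
    }
    return opt_type, opt_param

opt_type, opt_param = get_opt_param(optimizer_params)
# opt_type == "Adam"
# opt_param == {"adam_beta1": 0.9, "adam_beta2": 0.999, "weight_decay": 0.0}

The resolved values are then forwarded to paddle.optimizer.Adam as beta1/beta2/weight_decay in the last hunk above; omitting the optimizer section entirely keeps the previous Adam behaviour (beta1=0.9, beta2=0.999, no weight decay) because every key falls back to its default.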
78 changes: 60 additions & 18 deletions deepmd/pt/train/training.py
@@ -125,6 +125,7 @@ def __init__(
self.restart_training = restart_model is not None
model_params = config["model"]
training_params = config["training"]
+ optimizer_params = config.get("optimizer") or {}
self.multi_task = "model_dict" in model_params
self.finetune_links = finetune_links
self.finetune_update_stat = False
@@ -157,7 +158,24 @@ def __init__(
self.lcurve_should_print_header = True

def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
opt_type = params.get("opt_type", "Adam")
opt_type = params.get("type", "Adam")
if opt_type == "Adam":
default_adam_beta2 = 0.999
default_weight_decay = 0.0
elif opt_type == "AdamW":
default_adam_beta2 = 0.999
default_weight_decay = 0.0
elif opt_type == "LKF":
default_adam_beta2 = 0.95
default_weight_decay = 0.001
elif opt_type == "AdaMuon":
default_adam_beta2 = 0.95
default_weight_decay = 0.001
elif opt_type == "HybridMuon":
default_adam_beta2 = 0.95
default_weight_decay = 0.001
else:
raise ValueError(f"Not supported optimizer type '{opt_type}'")
opt_param = {
# LKF parameters
"kf_blocksize": params.get("kf_blocksize", 5120),
@@ -166,11 +184,11 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
"kf_start_pref_f": params.get("kf_start_pref_f", 1),
"kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
# Common parameters
"weight_decay": params.get("weight_decay", 0.001),
"weight_decay": params.get("weight_decay", default_weight_decay),
# Muon/AdaMuon parameters
"momentum": params.get("momentum", 0.95),
"adam_beta1": params.get("adam_beta1", 0.9),
"adam_beta2": params.get("adam_beta2", 0.95),
"adam_beta2": params.get("adam_beta2", default_adam_beta2),
"lr_adjust": params.get("lr_adjust", 10.0),
"lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),
"muon_2d_only": params.get("muon_2d_only", True),
@@ -299,7 +317,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
self.optim_dict[model_key]
)
else:
- self.opt_type, self.opt_param = get_opt_param(training_params)
+ self.opt_type, self.opt_param = get_opt_param(optimizer_params)

# loss_param_tmp for Hessian activation
loss_param_tmp = None
@@ -712,20 +730,38 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:

# TODO add optimizers for multitask
# author: iProzd
- if self.opt_type in ["Adam", "AdamW"]:
- if self.opt_type == "Adam":
- self.optimizer = torch.optim.Adam(
- self.wrapper.parameters(),
- lr=self.lr_exp.start_lr,
- fused=False if DEVICE.type == "cpu" else True,
- )
- else:
- self.optimizer = torch.optim.AdamW(
- self.wrapper.parameters(),
- lr=self.lr_exp.start_lr,
- weight_decay=float(self.opt_param["weight_decay"]),
- fused=False if DEVICE.type == "cpu" else True,
- )
+ if self.opt_type == "Adam":
+ adam_betas = (
+ float(self.opt_param["adam_beta1"]),
+ float(self.opt_param["adam_beta2"]),
+ )
+ weight_decay = float(self.opt_param["weight_decay"])
+ self.optimizer = torch.optim.Adam(
+ self.wrapper.parameters(),
+ lr=self.lr_exp.start_lr,
+ betas=adam_betas,
+ weight_decay=weight_decay,
+ fused=False if DEVICE.type == "cpu" else True,
+ )
+ if optimizer_state_dict is not None and self.restart_training:
+ self.optimizer.load_state_dict(optimizer_state_dict)
+ self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+ self.optimizer,
+ lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+ )
+ elif self.opt_type == "AdamW":
+ adam_betas = (
+ float(self.opt_param["adam_beta1"]),
+ float(self.opt_param["adam_beta2"]),
+ )
+ weight_decay = float(self.opt_param["weight_decay"])
+ self.optimizer = torch.optim.AdamW(
+ self.wrapper.parameters(),
+ lr=self.lr_exp.start_lr,
+ betas=adam_betas,
+ weight_decay=weight_decay,
+ fused=False if DEVICE.type == "cpu" else True,
+ )
if optimizer_state_dict is not None and self.restart_training:
self.optimizer.load_state_dict(optimizer_state_dict)
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -749,6 +785,12 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
lr_adjust=float(self.opt_param["lr_adjust"]),
lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
)
+ if optimizer_state_dict is not None and self.restart_training:
+ self.optimizer.load_state_dict(optimizer_state_dict)
+ self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+ self.optimizer,
+ lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+ )
elif self.opt_type == "HybridMuon":
self.optimizer = HybridMuonOptimizer(
self.wrapper.parameters(),
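For the PyTorch trainer, the effect of the new opt_param entries is easiest to see in isolation. The sketch below (not part of the diff) builds a torch Adam/AdamW optimizer from a resolved opt_param dict the way the branches above do; the toy module and the literal values stand in for self.wrapper and the values coming from get_opt_param, and the fused/device handling is left out.

# Sketch only: how the resolved betas and weight decay reach torch.optim.Adam/AdamW.
# "model" and the literal values are placeholders, not the trainer's real objects.
import torch

model = torch.nn.Linear(4, 1)
opt_type = "AdamW"
opt_param = {"adam_beta1": 0.9, "adam_beta2": 0.999, "weight_decay": 0.0}
start_lr = 1e-3

adam_betas = (float(opt_param["adam_beta1"]), float(opt_param["adam_beta2"]))
weight_decay = float(opt_param["weight_decay"])
optimizer_cls = torch.optim.Adam if opt_type == "Adam" else torch.optim.AdamW
optimizer = optimizer_cls(
    model.parameters(),
    lr=start_lr,
    betas=adam_betas,
    weight_decay=weight_decay,
)

Note the per-type defaults chosen in get_opt_param above: Adam and AdamW now default to adam_beta2=0.999 and weight_decay=0.0, while LKF, AdaMuon and HybridMuon keep the 0.95 and 0.001 defaults.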
35 changes: 32 additions & 3 deletions deepmd/tf/train/trainer.py
@@ -120,6 +120,21 @@ def get_lr_and_coef(lr_param):
# learning rate
lr_param = jdata["learning_rate"]
self.lr, self.scale_lr_coef = get_lr_and_coef(lr_param)
+ # optimizer
+ optimizer_param = jdata.get("optimizer") or {}
+ self.optimizer_type = optimizer_param.get("type", "Adam")
+ self.optimizer_beta1 = float(optimizer_param.get("adam_beta1", 0.9))
+ self.optimizer_beta2 = float(optimizer_param.get("adam_beta2", 0.999))
+ self.optimizer_weight_decay = float(optimizer_param.get("weight_decay", 0.0))
+ if self.optimizer_type != "Adam":
+ raise RuntimeError(
+ f"Unsupported optimizer type {self.optimizer_type} for TensorFlow backend."
+ )
+ if self.optimizer_weight_decay != 0.0:
+ raise RuntimeError(
+ "TensorFlow Adam optimizer does not support weight_decay. "
+ "Set optimizer/weight_decay to 0."
+ )
# loss
# infer loss type by fitting_type
loss_param = jdata.get("loss", {})
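A short illustration (not part of the diff) of what the new checks accept on the TensorFlow backend: only type "Adam" is supported and weight_decay must stay at 0, so the two hypothetical jdata fragments below behave as follows.

# Sketch only: hypothetical inputs to the checks added above.
jdata_ok = {"optimizer": {"type": "Adam", "adam_beta1": 0.9, "adam_beta2": 0.999}}
jdata_rejected = {"optimizer": {"type": "AdamW", "weight_decay": 0.01}}  # raises RuntimeError in the trainer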
@@ -306,17 +321,31 @@ def _build_network(self, data, suffix="") -> None:
log.info("built network")

def _build_optimizer(self):
+ if self.optimizer_type != "Adam":
+ raise RuntimeError(
+ f"Unsupported optimizer type {self.optimizer_type} for TensorFlow backend."
+ )
if self.run_opt.is_distrib:
if self.scale_lr_coef > 1.0:
log.info("Scale learning rate by coef: %f", self.scale_lr_coef)
optimizer = tf.train.AdamOptimizer(
- self.learning_rate * self.scale_lr_coef
+ self.learning_rate * self.scale_lr_coef,
+ beta1=self.optimizer_beta1,
+ beta2=self.optimizer_beta2,
)
else:
- optimizer = tf.train.AdamOptimizer(self.learning_rate)
+ optimizer = tf.train.AdamOptimizer(
+ self.learning_rate,
+ beta1=self.optimizer_beta1,
+ beta2=self.optimizer_beta2,
+ )
optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
else:
- optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
+ optimizer = tf.train.AdamOptimizer(
+ learning_rate=self.learning_rate,
+ beta1=self.optimizer_beta1,
+ beta2=self.optimizer_beta2,
+ )

if self.mixed_prec is not None:
_TF_VERSION = Version(TF_VERSION)
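Finally, a sketch (not part of the diff) of where the beta values land: the TF1-style Adam optimizer that the trainer builds. It uses tf.compat.v1 directly, whereas DeePMD-kit goes through its own tf wrapper, and the literal values stand in for self.optimizer_beta1/beta2 and self.learning_rate.

# Sketch only: TF1-style Adam construction with the configurable betas.
import tensorflow.compat.v1 as tf

learning_rate = 1e-3
optimizer = tf.train.AdamOptimizer(
    learning_rate=learning_rate,
    beta1=0.9,    # optimizer/adam_beta1
    beta2=0.999,  # optimizer/adam_beta2
)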