diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py
index dd0fbdc94b..8c89c7bc6c 100644
--- a/deepmd/pd/train/training.py
+++ b/deepmd/pd/train/training.py
@@ -115,6 +115,7 @@ def __init__(
         self.restart_training = restart_model is not None
         model_params = config["model"]
         training_params = config["training"]
+        optimizer_params = config.get("optimizer") or {}
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
         self.finetune_update_stat = False
@@ -149,13 +150,18 @@ def __init__(
         self.lcurve_should_print_header = True

         def get_opt_param(params):
-            opt_type = params.get("opt_type", "Adam")
+            opt_type = params.get("type", "Adam")
+            if opt_type != "Adam":
+                raise ValueError(f"Unsupported optimizer type '{opt_type}'")
             opt_param = {
                 "kf_blocksize": params.get("kf_blocksize", 5120),
                 "kf_start_pref_e": params.get("kf_start_pref_e", 1),
                 "kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
                 "kf_start_pref_f": params.get("kf_start_pref_f", 1),
                 "kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
+                "adam_beta1": params.get("adam_beta1", 0.9),
+                "adam_beta2": params.get("adam_beta2", 0.999),
+                "weight_decay": params.get("weight_decay", 0.0),
             }
             return opt_type, opt_param

@@ -259,7 +265,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
                     self.optim_dict[model_key]
                 )
         else:
-            self.opt_type, self.opt_param = get_opt_param(training_params)
+            self.opt_type, self.opt_param = get_opt_param(optimizer_params)

         # loss_param_tmp for Hessian activation
         loss_param_tmp = None
@@ -594,7 +600,11 @@ def warm_up_linear(step, warmup_steps):
             lr_lambda=lambda step: warm_up_linear(step, self.warmup_steps),
         )
         self.optimizer = paddle.optimizer.Adam(
-            learning_rate=self.scheduler, parameters=self.wrapper.parameters()
+            learning_rate=self.scheduler,
+            parameters=self.wrapper.parameters(),
+            beta1=float(self.opt_param["adam_beta1"]),
+            beta2=float(self.opt_param["adam_beta2"]),
+            weight_decay=float(self.opt_param["weight_decay"]),
         )
         if optimizer_state_dict is not None and self.restart_training:
             self.optimizer.set_state_dict(optimizer_state_dict)
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 20497a0ceb..005c3ddbdc 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -125,6 +125,7 @@ def __init__(
         self.restart_training = restart_model is not None
         model_params = config["model"]
         training_params = config["training"]
+        optimizer_params = config.get("optimizer") or {}
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
         self.finetune_update_stat = False
@@ -157,7 +158,24 @@ def __init__(
         self.lcurve_should_print_header = True

         def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
-            opt_type = params.get("opt_type", "Adam")
+            opt_type = params.get("type", "Adam")
+            if opt_type == "Adam":
+                default_adam_beta2 = 0.999
+                default_weight_decay = 0.0
+            elif opt_type == "AdamW":
+                default_adam_beta2 = 0.999
+                default_weight_decay = 0.0
+            elif opt_type == "LKF":
+                default_adam_beta2 = 0.95
+                default_weight_decay = 0.001
+            elif opt_type == "AdaMuon":
+                default_adam_beta2 = 0.95
+                default_weight_decay = 0.001
+            elif opt_type == "HybridMuon":
+                default_adam_beta2 = 0.95
+                default_weight_decay = 0.001
+            else:
+                raise ValueError(f"Unsupported optimizer type '{opt_type}'")
             opt_param = {
                 # LKF parameters
                 "kf_blocksize": params.get("kf_blocksize", 5120),
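As a quick reference for reviewers, the per-type defaults that the dispatch above encodes can be restated in a small standalone sketch (illustrative only, not trainer code; `DEFAULTS` and `resolve` are made-up names):

```python
# Restatement of get_opt_param's per-type defaults: beta2 and weight_decay
# depend on the optimizer type unless the user overrides them explicitly.
DEFAULTS = {
    "Adam": (0.999, 0.0),
    "AdamW": (0.999, 0.0),
    "LKF": (0.95, 0.001),
    "AdaMuon": (0.95, 0.001),
    "HybridMuon": (0.95, 0.001),
}


def resolve(params: dict) -> dict:
    opt_type = params.get("type", "Adam")
    beta2, weight_decay = DEFAULTS[opt_type]
    return {
        "adam_beta1": params.get("adam_beta1", 0.9),
        "adam_beta2": params.get("adam_beta2", beta2),
        "weight_decay": params.get("weight_decay", weight_decay),
    }


assert resolve({}) == {"adam_beta1": 0.9, "adam_beta2": 0.999, "weight_decay": 0.0}
assert resolve({"type": "AdaMuon"})["adam_beta2"] == 0.95
```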
"kf_limit_pref_f": params.get("kf_limit_pref_f", 1), # Common parameters - "weight_decay": params.get("weight_decay", 0.001), + "weight_decay": params.get("weight_decay", default_weight_decay), # Muon/AdaMuon parameters "momentum": params.get("momentum", 0.95), "adam_beta1": params.get("adam_beta1", 0.9), - "adam_beta2": params.get("adam_beta2", 0.95), + "adam_beta2": params.get("adam_beta2", default_adam_beta2), "lr_adjust": params.get("lr_adjust", 10.0), "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2), "muon_2d_only": params.get("muon_2d_only", True), @@ -299,7 +317,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: self.optim_dict[model_key] ) else: - self.opt_type, self.opt_param = get_opt_param(training_params) + self.opt_type, self.opt_param = get_opt_param(optimizer_params) # loss_param_tmp for Hessian activation loss_param_tmp = None @@ -712,20 +730,38 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: # TODO add optimizers for multitask # author: iProzd - if self.opt_type in ["Adam", "AdamW"]: - if self.opt_type == "Adam": - self.optimizer = torch.optim.Adam( - self.wrapper.parameters(), - lr=self.lr_exp.start_lr, - fused=False if DEVICE.type == "cpu" else True, - ) - else: - self.optimizer = torch.optim.AdamW( - self.wrapper.parameters(), - lr=self.lr_exp.start_lr, - weight_decay=float(self.opt_param["weight_decay"]), - fused=False if DEVICE.type == "cpu" else True, - ) + if self.opt_type == "Adam": + adam_betas = ( + float(self.opt_param["adam_beta1"]), + float(self.opt_param["adam_beta2"]), + ) + weight_decay = float(self.opt_param["weight_decay"]) + self.optimizer = torch.optim.Adam( + self.wrapper.parameters(), + lr=self.lr_exp.start_lr, + betas=adam_betas, + weight_decay=weight_decay, + fused=False if DEVICE.type == "cpu" else True, + ) + if optimizer_state_dict is not None and self.restart_training: + self.optimizer.load_state_dict(optimizer_state_dict) + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, + lambda step: warm_up_linear(step + self.start_step, self.warmup_steps), + ) + elif self.opt_type == "AdamW": + adam_betas = ( + float(self.opt_param["adam_beta1"]), + float(self.opt_param["adam_beta2"]), + ) + weight_decay = float(self.opt_param["weight_decay"]) + self.optimizer = torch.optim.AdamW( + self.wrapper.parameters(), + lr=self.lr_exp.start_lr, + betas=adam_betas, + weight_decay=weight_decay, + fused=False if DEVICE.type == "cpu" else True, + ) if optimizer_state_dict is not None and self.restart_training: self.optimizer.load_state_dict(optimizer_state_dict) self.scheduler = torch.optim.lr_scheduler.LambdaLR( @@ -749,6 +785,12 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: lr_adjust=float(self.opt_param["lr_adjust"]), lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]), ) + if optimizer_state_dict is not None and self.restart_training: + self.optimizer.load_state_dict(optimizer_state_dict) + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, + lambda step: warm_up_linear(step + self.start_step, self.warmup_steps), + ) elif self.opt_type == "HybridMuon": self.optimizer = HybridMuonOptimizer( self.wrapper.parameters(), diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 4af59fd290..3f13089dfd 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -120,6 +120,21 @@ def get_lr_and_coef(lr_param): # learning rate lr_param = jdata["learning_rate"] self.lr, self.scale_lr_coef = get_lr_and_coef(lr_param) + # optimizer + optimizer_param = 
jdata.get("optimizer") or {} + self.optimizer_type = optimizer_param.get("type", "Adam") + self.optimizer_beta1 = float(optimizer_param.get("adam_beta1", 0.9)) + self.optimizer_beta2 = float(optimizer_param.get("adam_beta2", 0.999)) + self.optimizer_weight_decay = float(optimizer_param.get("weight_decay", 0.0)) + if self.optimizer_type != "Adam": + raise RuntimeError( + f"Unsupported optimizer type {self.optimizer_type} for TensorFlow backend." + ) + if self.optimizer_weight_decay != 0.0: + raise RuntimeError( + "TensorFlow Adam optimizer does not support weight_decay. " + "Set optimizer/weight_decay to 0." + ) # loss # infer loss type by fitting_type loss_param = jdata.get("loss", {}) @@ -306,17 +321,31 @@ def _build_network(self, data, suffix="") -> None: log.info("built network") def _build_optimizer(self): + if self.optimizer_type != "Adam": + raise RuntimeError( + f"Unsupported optimizer type {self.optimizer_type} for TensorFlow backend." + ) if self.run_opt.is_distrib: if self.scale_lr_coef > 1.0: log.info("Scale learning rate by coef: %f", self.scale_lr_coef) optimizer = tf.train.AdamOptimizer( - self.learning_rate * self.scale_lr_coef + self.learning_rate * self.scale_lr_coef, + beta1=self.optimizer_beta1, + beta2=self.optimizer_beta2, ) else: - optimizer = tf.train.AdamOptimizer(self.learning_rate) + optimizer = tf.train.AdamOptimizer( + self.learning_rate, + beta1=self.optimizer_beta1, + beta2=self.optimizer_beta2, + ) optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer) else: - optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) + optimizer = tf.train.AdamOptimizer( + learning_rate=self.learning_rate, + beta1=self.optimizer_beta1, + beta2=self.optimizer_beta2, + ) if self.mixed_prec is not None: _TF_VERSION = Version(TF_VERSION) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8c20bb8bf4..344e9329d9 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2565,6 +2565,292 @@ def learning_rate_args(fold_subdoc: bool = False) -> Argument: ) +# --- Optimizer configurations: --- # +opt_args_plugin = ArgsPlugin() + + +@opt_args_plugin.register("Adam") +def optimizer_adam() -> list[Argument]: + doc_adam_beta1 = "Adam beta1 coefficient for first moment decay." + doc_adam_beta2 = "Adam beta2 coefficient for second moment decay." + doc_weight_decay = ( + "Weight decay coefficient for Adam. In PyTorch and Paddle, this is an L2 " + "penalty applied to gradients. TensorFlow does not support weight_decay and " + "requires this value to be 0." + ) + return [ + Argument( + "adam_beta1", + float, + optional=True, + default=0.9, + doc=doc_adam_beta1, + ), + Argument( + "adam_beta2", + float, + optional=True, + default=0.999, + doc=doc_adam_beta2, + ), + Argument( + "weight_decay", + float, + optional=True, + default=0.0, + doc=doc_weight_decay, + ), + ] + + +@opt_args_plugin.register("AdamW", doc=doc_only_pt_supported) +def optimizer_adamw() -> list[Argument]: + doc_adam_beta1 = "AdamW beta1 coefficient for first moment decay." + doc_adam_beta2 = "AdamW beta2 coefficient for second moment decay." + doc_weight_decay = ( + "Decoupled weight decay coefficient for AdamW optimizer (PyTorch only)." 
+@opt_args_plugin.register("AdamW", doc=doc_only_pt_supported)
+def optimizer_adamw() -> list[Argument]:
+    doc_adam_beta1 = "AdamW beta1 coefficient for first moment decay."
+    doc_adam_beta2 = "AdamW beta2 coefficient for second moment decay."
+    doc_weight_decay = (
+        "Decoupled weight decay coefficient for AdamW optimizer (PyTorch only)."
+    )
+    return [
+        Argument(
+            "adam_beta1",
+            float,
+            optional=True,
+            default=0.9,
+            doc=doc_only_pt_supported + doc_adam_beta1,
+        ),
+        Argument(
+            "adam_beta2",
+            float,
+            optional=True,
+            default=0.999,
+            doc=doc_only_pt_supported + doc_adam_beta2,
+        ),
+        Argument(
+            "weight_decay",
+            float,
+            optional=True,
+            default=0.0,
+            doc=doc_only_pt_supported + doc_weight_decay,
+        ),
+    ]
+
+
+@opt_args_plugin.register("LKF", doc=doc_only_pt_supported)
+def optimizer_lkf() -> list[Argument]:
+    doc_kf_blocksize = "The blocksize for the Kalman filter."
+    doc_kf_start_pref_e = (
+        "The prefactor of energy loss at the start of Kalman filter updates."
+    )
+    doc_kf_limit_pref_e = (
+        "The prefactor of energy loss at the end of training for Kalman filter updates."
+    )
+    doc_kf_start_pref_f = (
+        "The prefactor of force loss at the start of Kalman filter updates."
+    )
+    doc_kf_limit_pref_f = (
+        "The prefactor of force loss at the end of training for Kalman filter updates."
+    )
+    return [
+        Argument(
+            "kf_blocksize",
+            int,
+            optional=True,
+            default=5120,
+            doc=doc_only_pt_supported + doc_kf_blocksize,
+        ),
+        Argument(
+            "kf_start_pref_e",
+            float,
+            optional=True,
+            default=1.0,
+            doc=doc_only_pt_supported + doc_kf_start_pref_e,
+        ),
+        Argument(
+            "kf_limit_pref_e",
+            float,
+            optional=True,
+            default=1.0,
+            doc=doc_only_pt_supported + doc_kf_limit_pref_e,
+        ),
+        Argument(
+            "kf_start_pref_f",
+            float,
+            optional=True,
+            default=1.0,
+            doc=doc_only_pt_supported + doc_kf_start_pref_f,
+        ),
+        Argument(
+            "kf_limit_pref_f",
+            float,
+            optional=True,
+            default=1.0,
+            doc=doc_only_pt_supported + doc_kf_limit_pref_f,
+        ),
+    ]
+
+
+@opt_args_plugin.register("AdaMuon", doc=doc_only_pt_supported)
+def optimizer_adamuon() -> list[Argument]:
+    return [
+        Argument(
+            "momentum",
+            float,
+            optional=True,
+            default=0.95,
+            alias=["muon_momentum"],
+            doc=doc_only_pt_supported + "Momentum coefficient for AdaMuon optimizer.",
+        ),
+        Argument(
+            "adam_beta1",
+            float,
+            optional=True,
+            default=0.9,
+            doc=doc_only_pt_supported + "Adam beta1 coefficient for AdaMuon optimizer.",
+        ),
+        Argument(
+            "adam_beta2",
+            float,
+            optional=True,
+            default=0.95,
+            doc=doc_only_pt_supported + "Adam beta2 coefficient for AdaMuon optimizer.",
+        ),
+        Argument(
+            "weight_decay",
+            float,
+            optional=True,
+            default=0.001,
+            doc=doc_only_pt_supported
+            + "Weight decay coefficient. Applied only to >=2D parameters (AdaMuon path).",
+        ),
+        Argument(
+            "lr_adjust",
+            float,
+            optional=True,
+            default=10.0,
+            doc=doc_only_pt_supported
+            + "Learning rate adjustment factor for Adam (1D params). "
+            "If lr_adjust <= 0: use match-RMS scaling (scale = lr_adjust_coeff * sqrt(max(m, n))), Adam uses lr directly. "
+            "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1.0, m/n))), Adam uses lr/lr_adjust.",
+        ),
+        Argument(
+            "lr_adjust_coeff",
+            float,
+            optional=True,
+            default=0.2,
+            doc=doc_only_pt_supported
+            + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.",
+        ),
+    ]
+
+
" + "Used in Nesterov momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t.", + ), + Argument( + "adam_beta1", + float, + optional=True, + default=0.9, + doc=doc_only_pt_supported + + "Adam beta1 coefficient for 1D parameters (biases, norms).", + ), + Argument( + "adam_beta2", + float, + optional=True, + default=0.95, + doc=doc_only_pt_supported + + "Adam beta2 coefficient for 1D parameters (biases, norms).", + ), + Argument( + "weight_decay", + float, + optional=True, + default=0.001, + doc=doc_only_pt_supported + + "Weight decay coefficient. Applied only to Muon-routed parameters", + ), + Argument( + "lr_adjust", + float, + optional=True, + default=10.0, + doc=doc_only_pt_supported + + "Learning rate adjustment mode for HybridMuon scaling and Adam learning rate. " + "If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. " + "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. " + "Default is 10.0 (Adam lr = lr/10).", + ), + Argument( + "lr_adjust_coeff", + float, + optional=True, + default=0.2, + doc=doc_only_pt_supported + + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.", + ), + Argument( + "muon_2d_only", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + + "If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). " + + "Parameters with ndim > 2 use Adam without weight decay. " + + "If False, all >=2D parameters use Muon.", + ), + Argument( + "min_2d_dim", + int, + optional=True, + default=1, + alias=["muon_min_2d_dim"], + doc=doc_only_pt_supported + + "Minimum min(m, n) threshold for HybridMuon on 2D matrices. " + "Matrices with min(m, n) >= min_2d_dim use HybridMuon; " + "those with min(m, n) < min_2d_dim use Adam fallback. " + "Set to 1 to disable fallback.", + ), + ] + + +def optimizer_variant_type_args() -> Variant: + doc_opt_type = "The type of optimizer to use." + return Variant( + "type", + opt_args_plugin.get_all_argument(), + optional=True, + default_tag="Adam", + doc=doc_opt_type, + ) + + +def optimizer_args(fold_subdoc: bool = False) -> Argument: + doc_optimizer = ( + "The definition of optimizer. Supported optimizer types depend on backend: " + "TensorFlow/Paddle: Adam; PyTorch: Adam, AdamW, LKF, AdaMuon, HybridMuon." + ) + return Argument( + "optimizer", + dict, + [], + [optimizer_variant_type_args()], + optional=True, + doc=doc_optimizer, + fold_subdoc=fold_subdoc, + ) + + # --- Loss configurations: --- # def start_pref(item: str, label: str | None = None, abbr: str | None = None) -> str: if label is None: @@ -3268,8 +3554,6 @@ def training_args( "If the file extension is .h5 or .hdf5, an HDF5 file is used to store the statistics; " "otherwise, a directory containing NumPy binary files are used." ) - doc_opt_type = "The type of optimizer to use." - doc_kf_blocksize = "The blocksize for the Kalman filter." doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode." doc_data_dict = "The multiple definition of the data, used in the multi-task mode." doc_acc_freq = "Gradient accumulation steps (number of steps to accumulate gradients before performing an update)." 
@@ -3396,183 +3680,8 @@ def training_args(
                 doc=doc_only_pd_supported + doc_acc_freq,
             ),
         ]
-    variants = [
-        Variant(
-            "opt_type",
-            choices=[
-                Argument("Adam", dict, [], [], optional=True),
-                Argument("AdamW", dict, [], [], optional=True),
-                Argument(
-                    "LKF",
-                    dict,
-                    [
-                        Argument(
-                            "kf_blocksize",
-                            int,
-                            optional=True,
-                            doc=doc_only_pt_supported + doc_kf_blocksize,
-                        ),
-                    ],
-                    [],
-                    optional=True,
-                ),
-                Argument(
-                    "AdaMuon",
-                    dict,
-                    [
-                        Argument(
-                            "momentum",
-                            float,
-                            optional=True,
-                            default=0.95,
-                            alias=["muon_momentum"],
-                            doc=doc_only_pt_supported
-                            + "Momentum coefficient for AdaMuon optimizer.",
-                        ),
-                        Argument(
-                            "adam_beta1",
-                            float,
-                            optional=True,
-                            default=0.9,
-                            doc=doc_only_pt_supported
-                            + "Adam beta1 coefficient for AdaMuon optimizer.",
-                        ),
-                        Argument(
-                            "adam_beta2",
-                            float,
-                            optional=True,
-                            default=0.95,
-                            doc=doc_only_pt_supported
-                            + "Adam beta2 coefficient for AdaMuon optimizer.",
-                        ),
-                        Argument(
-                            "weight_decay",
-                            float,
-                            optional=True,
-                            default=0.001,
-                            doc=doc_only_pt_supported
-                            + "Weight decay coefficient. Applied only to >=2D parameters (AdaMuon path).",
-                        ),
-                        Argument(
-                            "lr_adjust",
-                            float,
-                            optional=True,
-                            default=10.0,
-                            doc=doc_only_pt_supported
-                            + "Learning rate adjustment factor for Adam (1D params). "
-                            "If lr_adjust <= 0: use match-RMS scaling (scale = lr_adjust_coeff * sqrt(max(m, n))), Adam uses lr directly. "
-                            "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1.0, m/n))), Adam uses lr/lr_adjust.",
-                        ),
-                        Argument(
-                            "lr_adjust_coeff",
-                            float,
-                            optional=True,
-                            default=0.2,
-                            doc=doc_only_pt_supported
-                            + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.",
-                        ),
-                    ],
-                    [],
-                    optional=True,
-                ),
-                Argument(
-                    "HybridMuon",
-                    dict,
-                    [
-                        Argument(
-                            "momentum",
-                            float,
-                            optional=True,
-                            default=0.95,
-                            alias=["muon_momentum"],
-                            doc=doc_only_pt_supported
-                            + "Momentum coefficient for HybridMuon optimizer (>=2D params). "
-                            "Used in Nesterov momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t.",
-                        ),
-                        Argument(
-                            "adam_beta1",
-                            float,
-                            optional=True,
-                            default=0.9,
-                            doc=doc_only_pt_supported
-                            + "Adam beta1 coefficient for 1D parameters (biases, norms).",
-                        ),
-                        Argument(
-                            "adam_beta2",
-                            float,
-                            optional=True,
-                            default=0.95,
-                            doc=doc_only_pt_supported
-                            + "Adam beta2 coefficient for 1D parameters (biases, norms).",
-                        ),
-                        Argument(
-                            "weight_decay",
-                            float,
-                            optional=True,
-                            default=0.001,
-                            doc=doc_only_pt_supported
-                            + "Weight decay coefficient. Applied only to Muon-routed parameters",
-                        ),
-                        Argument(
-                            "lr_adjust",
-                            float,
-                            optional=True,
-                            default=10.0,
-                            doc=doc_only_pt_supported
-                            + "Learning rate adjustment mode for HybridMuon scaling and Adam learning rate. "
-                            "If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. "
-                            "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. "
-                            "Default is 10.0 (Adam lr = lr/10).",
-                        ),
-                        Argument(
-                            "lr_adjust_coeff",
-                            float,
-                            optional=True,
-                            default=0.2,
-                            doc=doc_only_pt_supported
-                            + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.",
-                        ),
-                        Argument(
-                            "muon_2d_only",
-                            bool,
-                            optional=True,
-                            default=True,
-                            doc=doc_only_pt_supported
-                            + "If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). "
-                            + "Parameters with ndim > 2 use Adam without weight decay. "
-                            + "If False, all >=2D parameters use Muon.",
-                        ),
-                        Argument(
-                            "min_2d_dim",
-                            int,
-                            optional=True,
-                            default=1,
-                            alias=["muon_min_2d_dim"],
-                            doc=doc_only_pt_supported
-                            + "Minimum min(m, n) threshold for HybridMuon on 2D matrices. "
-                            "Matrices with min(m, n) >= min_2d_dim use HybridMuon; "
-                            "those with min(m, n) < min_2d_dim use Adam fallback. "
-                            "Set to 1 to disable fallback.",
-                        ),
-                    ],
-                    [],
-                    optional=True,
-                    doc=doc_only_pt_supported
-                    + "HybridMuon optimizer (DeePMD-kit custom implementation). "
-                    + "This is a Hybrid optimizer that automatically combines Muon and Adam. "
-                    + "For >=2D params: Muon update with Newton-Schulz. "
-                    + "For 1D params: Standard Adam. "
-                    + "This is DIFFERENT from PyTorch's torch.optim.Muon which ONLY supports 2D parameters.",
-                ),
-            ],
-            optional=True,
-            default_tag="Adam",
-            doc=doc_only_pt_supported + doc_opt_type,
-        )
-    ]
-
     doc_training = "The training options."
-    return Argument("training", dict, args, variants, doc=doc_training)
+    return Argument("training", dict, args, [], doc=doc_training)


 def multi_model_args() -> list[Argument]:
" - + "If False, all >=2D parameters use Muon.", - ), - Argument( - "min_2d_dim", - int, - optional=True, - default=1, - alias=["muon_min_2d_dim"], - doc=doc_only_pt_supported - + "Minimum min(m, n) threshold for HybridMuon on 2D matrices. " - "Matrices with min(m, n) >= min_2d_dim use HybridMuon; " - "those with min(m, n) < min_2d_dim use Adam fallback. " - "Set to 1 to disable fallback.", - ), - ], - [], - optional=True, - doc=doc_only_pt_supported - + "HybridMuon optimizer (DeePMD-kit custom implementation). " - + "This is a Hybrid optimizer that automatically combines Muon and Adam. " - + "For >=2D params: Muon update with Newton-Schulz. " - + "For 1D params: Standard Adam. " - + "This is DIFFERENT from PyTorch's torch.optim.Muon which ONLY supports 2D parameters.", - ), - ], - optional=True, - default_tag="Adam", - doc=doc_only_pt_supported + doc_opt_type, - ) - ] - doc_training = "The training options." - return Argument("training", dict, args, variants, doc=doc_training) + return Argument("training", dict, args, [], doc=doc_training) def multi_model_args() -> list[Argument]: @@ -3646,6 +3755,7 @@ def gen_args(multi_task: bool = False) -> list[Argument]: return [ model_args(), learning_rate_args(), + optimizer_args(), loss_args(), training_args(multi_task=multi_task), nvnmd_args(), @@ -3654,6 +3764,7 @@ def gen_args(multi_task: bool = False) -> list[Argument]: return [ multi_model_args(), learning_rate_args(fold_subdoc=True), + optimizer_args(fold_subdoc=True), multi_loss_args(), training_args(multi_task=multi_task), nvnmd_args(fold_subdoc=True), diff --git a/doc/model/train-fitting-dos.md b/doc/model/train-fitting-dos.md index fb4a3677e5..0386406262 100644 --- a/doc/model/train-fitting-dos.md +++ b/doc/model/train-fitting-dos.md @@ -16,7 +16,7 @@ $deepmd_source_dir/examples/dos/input.json The training and validation data are also provided our examples. But note that **the data provided along with the examples are of limited amount, and should not be used to train a production model.** -Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model `, {ref}`learning_rate `, {ref}`loss ` and {ref}`training `. Most keywords remain the same as `ener` mode, and their meaning can be found [here](train-se-e2-a.md). To fit the `dos`, one needs to modify {ref}`model[standard]/fitting_net ` and {ref}`loss `. +Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model `, {ref}`learning_rate `, {ref}`optimizer `, {ref}`loss ` and {ref}`training `. Most keywords remain the same as `ener` mode, and their meaning can be found in the [SE-E2-A guide](train-se-e2-a.md). To fit the `dos`, one needs to modify {ref}`model[standard]/fitting_net ` and {ref}`loss `. ## The fitting Network diff --git a/doc/model/train-fitting-property.md b/doc/model/train-fitting-property.md index be1b63bf6f..4a08c255db 100644 --- a/doc/model/train-fitting-property.md +++ b/doc/model/train-fitting-property.md @@ -14,7 +14,7 @@ $deepmd_source_dir/examples/property/train The training and validation data are also provided our examples. But note that **the data provided along with the examples are of limited amount, and should not be used to train a production model.** -Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model `, {ref}`learning_rate `, {ref}`loss ` and {ref}`training `. Most keywords remain the same as `ener` mode, and their meaning can be found [here](train-se-atten.md). 
diff --git a/doc/model/train-fitting-property.md b/doc/model/train-fitting-property.md
index be1b63bf6f..4a08c255db 100644
--- a/doc/model/train-fitting-property.md
+++ b/doc/model/train-fitting-property.md
@@ -14,7 +14,7 @@ $deepmd_source_dir/examples/property/train

 The training and validation data are also provided our examples. But note that **the data provided along with the examples are of limited amount, and should not be used to train a production model.**

-Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model <model>`, {ref}`learning_rate <learning_rate>`, {ref}`loss <loss>` and {ref}`training <training>`. Most keywords remain the same as `ener` mode, and their meaning can be found [here](train-se-atten.md). To fit the `property`, one needs to modify {ref}`model[standard]/fitting_net <model[standard]/fitting_net>` and {ref}`loss <loss>`.
+Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model <model>`, {ref}`learning_rate <learning_rate>`, {ref}`optimizer <optimizer>`, {ref}`loss <loss>` and {ref}`training <training>`. Most keywords remain the same as `ener` mode, and their meaning can be found in the [SE-Atten guide](train-se-atten.md). To fit the `property`, one needs to modify {ref}`model[standard]/fitting_net <model[standard]/fitting_net>` and {ref}`loss <loss>`.

 ## The fitting Network

diff --git a/doc/model/train-fitting-tensor.md b/doc/model/train-fitting-tensor.md
index 29c95b2d68..83288c0b35 100644
--- a/doc/model/train-fitting-tensor.md
+++ b/doc/model/train-fitting-tensor.md
@@ -30,7 +30,7 @@ $deepmd_source_dir/examples/water_tensor/polar/polar_input_torch.json

 The training and validation data are also provided our examples. But note that **the data provided along with the examples are of limited amount, and should not be used to train a production model.**

-Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model <model>`, {ref}`learning_rate <learning_rate>`, {ref}`loss <loss>` and {ref}`training <training>`. Most keywords remain the same as `ener` mode, and their meaning can be found [here](train-se-e2-a.md).
+Similar to the `input.json` used in `ener` mode, training JSON is also divided into {ref}`model <model>`, {ref}`learning_rate <learning_rate>`, {ref}`optimizer <optimizer>`, {ref}`loss <loss>` and {ref}`training <training>`. Most keywords remain the same as `ener` mode, and their meaning can be found in the [SE-E2-A guide](train-se-e2-a.md).
 To fit a tensor, one needs to modify {ref}`fitting_net <model[standard]/fitting_net>` and {ref}`loss <loss>`.

 ## Theory
@@ -103,7 +103,7 @@ The JSON of `polar` type should be provided like
     },
 ```

-- `type` specifies which type of fitting net should be used. It should be either `dipole` or `polar`. Note that `global_polar` mode in version 1.x is already **deprecated** and is merged into `polar`. To specify whether a system is global or atomic, please see [here](train-se-e2-a.md).
+- `type` specifies which type of fitting net should be used. It should be either `dipole` or `polar`. Note that `global_polar` mode in version 1.x is already **deprecated** and is merged into `polar`. To specify whether a system is global or atomic, please see the [SE-E2-A guide](train-se-e2-a.md).
 - `sel_type` is a list specifying which type of atoms have the quantity you want to fit. For example, in the water system, `sel_type` is `[0]` since `0` represents atom `O`. If left unset, all types of atoms will be fitted.
 - The rest arguments have the same meaning as they do in `ener` mode.

@@ -139,7 +139,7 @@ The JSON of `polar` type should be provided like
     },
 ```

-- `type` specifies which type of fitting net should be used. It should be either `dipole` or `polar`. Note that `global_polar` mode in version 1.x is already **deprecated** and is merged into `polar`. To specify whether a system is global or atomic, please see [here](train-se-e2-a.md).
+- `type` specifies which type of fitting net should be used. It should be either `dipole` or `polar`. Note that `global_polar` mode in version 1.x is already **deprecated** and is merged into `polar`. To specify whether a system is global or atomic, please see the [SE-E2-A guide](train-se-e2-a.md).
 - `atom_exclude_types` is a list specifying the which type of atoms have the quantity you want to set to zero. For example, in the water system, `atom_exclude_types` is `[1]` since `1` represents atom `H`.
 - The rest arguments have the same meaning as they do in `ener` mode.
 :::
diff --git a/doc/train/training-advanced.md b/doc/train/training-advanced.md
index af4b4b31d9..ab98bca11e 100644
--- a/doc/train/training-advanced.md
+++ b/doc/train/training-advanced.md
@@ -45,6 +45,22 @@ The {ref}`learning_rate <learning_rate>` section in `input.json` is given as follows
 lr(t) = start_lr * decay_rate ^ ( t / decay_steps )
 ```

+## Optimizer
+
+The {ref}`optimizer <optimizer>` section in `input.json` is given as follows
+
+```json
+    "optimizer" :{
+        "type": "Adam",
+        "_comment": "that's all"
+    }
+```
+
+- TensorFlow/Paddle: only {ref}`Adam <optimizer[Adam]>` is supported.
+- PyTorch: {ref}`Adam <optimizer[Adam]>`, {ref}`AdamW <optimizer[AdamW]>`, {ref}`LKF <optimizer[LKF]>`, {ref}`AdaMuon <optimizer[AdaMuon]>`, {ref}`HybridMuon <optimizer[HybridMuon]>`.
+- {ref}`adam_beta1 <optimizer[Adam]/adam_beta1>` and {ref}`adam_beta2 <optimizer[Adam]/adam_beta2>` control the Adam/AdamW moment decay.
+- {ref}`weight_decay <optimizer[Adam]/weight_decay>` applies an L2 penalty in Adam, while {ref}`weight_decay <optimizer[AdamW]/weight_decay>` is decoupled in AdamW. TensorFlow does not support weight decay in Adam.
+
 ## Training parameters

 Other training parameters are given in the {ref}`training <training>` section.
diff --git a/source/tests/pt/model/water/lkf.json b/source/tests/pt/model/water/lkf.json
index 4385d02136..b597d10d28 100644
--- a/source/tests/pt/model/water/lkf.json
+++ b/source/tests/pt/model/water/lkf.json
@@ -42,6 +42,11 @@
     "stop_lr": 3.51e-8,
     "_comment": "that's all"
   },
+  "optimizer": {
+    "type": "LKF",
+    "kf_blocksize": 1024,
+    "_comment": "that's all"
+  },
   "loss": {
     "type": "ener",
     "start_pref_e": 0.02,
@@ -71,8 +76,6 @@
     "disp_file": "lcurve.out",
     "disp_freq": 1,
     "save_freq": 1,
-    "opt_type": "LKF",
-    "kf_blocksize": 1024,
     "_comment": "that's all"
   },
   "_comment": "that's all"
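The `lkf.json` update above shows the migration by hand: the legacy `opt_type`/`kf_*` keys move out of `training` and into the new top-level `optimizer` section. As a sketch only (not part of this change), the same move could be done mechanically; `migrate` and `LEGACY_KF_KEYS` are hypothetical names:

```python
import json

# Keys that historically lived under "training" next to "opt_type".
LEGACY_KF_KEYS = (
    "kf_blocksize",
    "kf_start_pref_e",
    "kf_limit_pref_e",
    "kf_start_pref_f",
    "kf_limit_pref_f",
)


def migrate(config: dict) -> dict:
    """Move legacy optimizer keys from 'training' into a top-level 'optimizer' section."""
    training = config.get("training", {})
    if "opt_type" in training and "optimizer" not in config:
        optimizer = {"type": training.pop("opt_type")}
        for key in LEGACY_KF_KEYS:
            if key in training:
                optimizer[key] = training.pop(key)
        config["optimizer"] = optimizer
    return config


legacy = {"training": {"numb_steps": 1, "opt_type": "LKF", "kf_blocksize": 1024}}
print(json.dumps(migrate(legacy)))
# {"training": {"numb_steps": 1}, "optimizer": {"type": "LKF", "kf_blocksize": 1024}}
```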