diff --git a/deepmd/dpmodel/utils/learning_rate.py b/deepmd/dpmodel/utils/learning_rate.py index 7ea50583e2..92fbf3776a 100644 --- a/deepmd/dpmodel/utils/learning_rate.py +++ b/deepmd/dpmodel/utils/learning_rate.py @@ -29,122 +29,442 @@ def __new__(cls: type, *args: Any, **kwargs: Any) -> Any: return super().__new__(cls) def __init__( - self, start_lr: float, stop_lr: float, stop_steps: int, **kwargs: Any + self, + start_lr: float, + stop_lr: float | None = None, + stop_lr_ratio: float | None = None, + num_steps: int = 100000, + warmup_steps: int = 0, + warmup_ratio: float | None = None, + warmup_start_factor: float = 0.0, + **kwargs: Any, ) -> None: """ - Base class for learning rate schedules. + Base class for learning rate schedules with warmup support. Parameters ---------- - start_lr - The initial learning rate. - stop_lr - The final learning rate. - stop_steps - The total training steps for learning rate scheduler. + start_lr : float + The learning rate at the start of the training (after warmup). + stop_lr : float, optional + The final learning rate at the end of the training. + Mutually exclusive with stop_lr_ratio. + stop_lr_ratio : float, optional + The ratio of stop_lr to start_lr. stop_lr = start_lr * stop_lr_ratio. + Mutually exclusive with stop_lr. + One of stop_lr or stop_lr_ratio must be provided. + num_steps : int + The total training steps (including warmup). + warmup_steps : int, optional + The number of steps for learning rate warmup. + Mutually exclusive with warmup_ratio. Default is 0 (no warmup). + warmup_ratio : float, optional + The ratio of warmup steps to total training steps. + warmup_steps = int(warmup_ratio * num_steps). + Mutually exclusive with warmup_steps. + warmup_start_factor : float, optional + The factor of start_lr for the initial warmup learning rate. + The warmup learning rate starts from warmup_start_factor * start_lr. + Default is 0.0. """ + # === Step 1. Validate stop_lr and stop_lr_ratio (runtime check) === + has_stop_lr = stop_lr is not None + has_stop_lr_ratio = stop_lr_ratio is not None + + if has_stop_lr and has_stop_lr_ratio: + raise ValueError( + "stop_lr and stop_lr_ratio are mutually exclusive. " + f"Got stop_lr={stop_lr}, stop_lr_ratio={stop_lr_ratio}" + ) + if not has_stop_lr and not has_stop_lr_ratio: + raise ValueError( + "Either stop_lr or stop_lr_ratio must be provided. " + "Got stop_lr=None, stop_lr_ratio=None" + ) + + # === Step 2. Compute stop_lr from stop_lr_ratio if needed === + if stop_lr_ratio is not None: + self.stop_lr = start_lr * stop_lr_ratio + else: + self.stop_lr = stop_lr + + # === Step 3. Validate warmup_steps and warmup_ratio (runtime check) === + has_warmup_steps = warmup_steps != 0 + has_warmup_ratio = warmup_ratio is not None + + if has_warmup_steps and has_warmup_ratio: + raise ValueError( + "warmup_steps and warmup_ratio are mutually exclusive. " + f"Got warmup_steps={warmup_steps}, warmup_ratio={warmup_ratio}" + ) + + # === Step 4. Compute warmup_steps from warmup_ratio if needed === + if warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * num_steps) + else: + self.warmup_steps = warmup_steps + + # === Step 5. 
Validate step ranges (runtime check) === + if num_steps < 0: + raise ValueError("num_steps must be non-negative") + if self.warmup_steps < 0: + raise ValueError("warmup_steps must be non-negative") + if num_steps > 0 and self.warmup_steps >= num_steps: + raise ValueError("warmup_steps must be smaller than num_steps") + if num_steps == 0 and self.warmup_steps != 0: + raise ValueError("warmup_steps must be 0 when num_steps is 0") + + # === Step 6. Compute warmup_start_lr === + self.warmup_start_lr = warmup_start_factor * start_lr + + # === Step 7. Store core parameters === self.start_lr = start_lr - self.stop_lr = stop_lr - self.stop_steps = stop_steps + self.num_steps = num_steps + # Decay phase covers (num_steps - warmup_steps) steps + self.decay_num_steps = num_steps - self.warmup_steps @abstractmethod - def value(self, step: int | Array) -> Array: - """Get the learning rate at the given step.""" - # in optax, step will be a jnp.ndarray passed in JIT mode + def _decay_value(self, step: int | Array) -> Array: + """ + Get the decayed learning rate at the given step (after warmup). + + This method should implement the actual decay logic (exp, cosine, etc.) + without considering warmup. + + Parameters + ---------- + step : int or Array + The step index relative to the end of warmup. + For example, if warmup_steps=100 and total_step=150, this method + will be called with step=50. + + Returns + ------- + Array + The decayed learning rate (absolute value, not factor). + """ pass + def value(self, step: int | Array) -> Array | float: + """ + Get the learning rate at the given step, including warmup. + + Parameters + ---------- + step : int or Array + The absolute step index from the start of training. + + Returns + ------- + Array + The learning rate at the given step. + """ + is_scalar = isinstance(step, (int, float)) + if not array_api_compat.is_array_api_obj(step): + step = np.asarray(step) + xp = array_api_compat.array_namespace(step) + + # === Step 1. Handle no-warmup case directly === + if self.warmup_steps == 0: + lr = self._decay_value(xp.astype(step, xp.float64)) + else: + # === Step 2. Warmup phase === + # Linear warmup from warmup_start_lr to start_lr + warmup_progress = xp.astype(step, xp.float64) / self.warmup_steps + warmup_lr = ( + self.warmup_start_lr + + (self.start_lr - self.warmup_start_lr) * warmup_progress + ) + + # === Step 3. Decay phase === + # Call subclass decay logic for steps after warmup + decay_step = xp.maximum( + xp.astype(step, xp.float64) - self.warmup_steps, 0.0 + ) + decay_lr = self._decay_value(decay_step) + + # === Step 4. Select warmup or decay based on step === + lr = xp.where(step < self.warmup_steps, warmup_lr, decay_lr) + + if is_scalar: + return float(lr) + return lr + @BaseLR.register("exp") class LearningRateExp(BaseLR): + r""" + Exponential decay learning rate schedule with optional warmup. + + The decay phase (after warmup) follows the exponential decay formula. + + **Stepped mode (smooth=False, default):** + + .. math:: + + lr(t) = lr_0 \cdot r^{\lfloor t / s \rfloor} + + The learning rate decays every ``decay_steps`` steps, creating a staircase + pattern. + + **Smooth mode (smooth=True):** + + .. math:: + + lr(t) = lr_0 \cdot r^{t / s} + + The learning rate decays continuously at every step. 
+ + where: + - :math:`lr_0` is ``start_lr`` (learning rate at the start of decay phase) + - :math:`r` is the decay rate ``decay_rate`` + - :math:`t` is the step index within the decay phase + - :math:`s` is ``decay_steps`` (the decay period) + + The decay rate is automatically computed from ``start_lr`` and ``stop_lr`` + over the total decay steps unless explicitly provided: + + .. math:: + + r = \left(\frac{lr_{\text{stop}}}{lr_0}\right)^{\frac{s}{T}} + + where :math:`T = \text{num\_steps} - \text{warmup\_steps}` is the total + number of decay steps, and :math:`lr_{\text{stop}}` is ``stop_lr``. + """ + def __init__( self, start_lr: float, - stop_lr: float, - decay_steps: int, - stop_steps: int, + stop_lr: float | None = None, + stop_lr_ratio: float | None = None, + decay_steps: int = 5000, + num_steps: int = 100000, decay_rate: float | None = None, + warmup_steps: int = 0, + warmup_ratio: float | None = None, + warmup_start_factor: float = 0.0, + smooth: bool = False, **kwargs: Any, ) -> None: """ - Construct an exponential-decayed learning rate. + Construct an exponential-decayed learning rate with optional warmup. Parameters ---------- - start_lr - The learning rate at the start of the training. - stop_lr + start_lr : float + The learning rate at the start of the training (after warmup). + stop_lr : float, optional The desired learning rate at the end of the training. When decay_rate is explicitly set, this value will serve as - the minimum learning rate during training. In other words, - if the learning rate decays below stop_lr, stop_lr will be applied instead. - decay_steps + the minimum learning rate during training. + Mutually exclusive with stop_lr_ratio. + stop_lr_ratio : float, optional + The ratio of stop_lr to start_lr. + Mutually exclusive with stop_lr. + decay_steps : int The learning rate is decaying every this number of training steps. - stop_steps - The total training steps for learning rate scheduler. - decay_rate + Default is 5000. + num_steps : int + The total training steps (including warmup). + decay_rate : float, optional The decay rate for the learning rate. If provided, the decay rate will be set instead of calculating it through interpolation between start_lr and stop_lr. + warmup_steps : int, optional + The number of steps for learning rate warmup. + Mutually exclusive with warmup_ratio. Default is 0. + warmup_ratio : float, optional + The ratio of warmup steps to total training steps. + Mutually exclusive with warmup_steps. + warmup_start_factor : float, optional + The factor of start_lr for the initial warmup learning rate. + Default is 0.0. + smooth : bool, optional + If True, use smooth exponential decay (lr decays continuously). + If False (default), use stepped decay (lr decays every decay_steps). + Default is False. + + Raises + ------ + ValueError + If both stop_lr and stop_lr_ratio are provided, or neither is provided. + If both warmup_steps and warmup_ratio are provided. + If decay_steps is not positive. """ - super().__init__(start_lr, stop_lr, stop_steps, **kwargs) - default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1 + super().__init__( + start_lr=start_lr, + stop_lr=stop_lr, + stop_lr_ratio=stop_lr_ratio, + num_steps=num_steps, + warmup_steps=warmup_steps, + warmup_ratio=warmup_ratio, + warmup_start_factor=warmup_start_factor, + **kwargs, + ) + # === Step 5. 
Compute decay_rate for exp scheduler === + # Use decay_num_steps (num_steps - warmup_steps) for decay calculation + decay_total = self.decay_num_steps self.decay_steps = decay_steps - if self.decay_steps >= stop_steps: + + if self.decay_steps <= 0: + raise ValueError(f"decay_steps ({self.decay_steps}) must be positive.") + + # Auto-adjust decay_steps if it exceeds decay_total and decay_rate is not provided + if decay_rate is None and self.decay_steps >= decay_total: + # Compute sensible default: cap at 100, but ensure at least 1 for small decay_total + default_ds = 100 if decay_total // 10 > 100 else decay_total // 100 + 1 self.decay_steps = default_ds - self.decay_rate = np.exp( - np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps) - ).item() + + # Avoid log(0) issues by clamping stop_lr for computation + clamped_stop_lr = max(self.stop_lr, 1e-10) + self.min_lr = self.stop_lr + + # Compute decay_rate from start_lr/stop_lr if not explicitly provided if decay_rate is not None: self.decay_rate = decay_rate - self.min_lr = self.stop_lr + elif decay_total == 0: + # No decay phase (num_steps == warmup_steps or num_steps == 0) + self.decay_rate = 1.0 # No decay + else: + self.decay_rate = np.exp( + np.log(clamped_stop_lr / self.start_lr) + / (decay_total / self.decay_steps) + ).item() + + # === Step 6. Store smooth mode === + self.smooth = smooth + + def _decay_value(self, step: int | Array) -> Array: + """ + Get the exponential-decayed learning rate factor at the given step. - def value(self, step: int | Array) -> Array: - """Get the learning rate at the given step.""" + Parameters + ---------- + step : int or Array + The step index relative to the end of warmup. + + Returns + ------- + Array + The decayed learning rate (absolute value). + """ if not array_api_compat.is_array_api_obj(step): step = np.asarray(step) xp = array_api_compat.array_namespace(step) + # === Step 1. Compute exponent based on smooth mode === + if self.smooth: + exponent = xp.astype(step, xp.float64) / self.decay_steps + else: + exponent = xp.astype(step // self.decay_steps, xp.float64) step_lr = self.start_lr * xp.pow( xp.asarray(self.decay_rate, device=array_api_compat.device(step)), - xp.astype(step // self.decay_steps, xp.float64), + exponent, ) - # the original implementation `if step_lr < self.min_lr:` - # will cause a dynamic graph which is unsupported in JAX JIT + # Clip to min_lr for numerical stability in JIT step_lr = xp.clip(step_lr, self.min_lr, None) return step_lr @BaseLR.register("cosine") class LearningRateCosine(BaseLR): + r""" + Cosine annealing learning rate schedule with optional warmup. + + The decay phase (after warmup) follows the cosine annealing formula: + + .. math:: + + lr(t) = lr_{\text{stop}} + \frac{lr_0 - lr_{\text{stop}}}{2} \left(1 + \cos\left(\pi \frac{t}{T}\right)\right) + + where: + - :math:`lr_0` is ``start_lr`` (learning rate at the start of decay phase) + - :math:`lr_{\text{stop}}` is ``stop_lr`` (minimum learning rate) + - :math:`t` is the step index within the decay phase + - :math:`T = \text{num\_steps} - \text{warmup\_steps}` is the total + number of decay steps + + Equivalently, using :math:`\alpha = lr_{\text{stop}} / lr_0`: + + .. 
math:: + + lr(t) = lr_0 \cdot \left[\alpha + \frac{1}{2}(1 - \alpha) \left(1 + \cos\left(\pi \frac{t}{T}\right)\right)\right] + """ + def __init__( self, start_lr: float, - stop_lr: float, - stop_steps: int, + stop_lr: float | None = None, + stop_lr_ratio: float | None = None, + num_steps: int = 100000, + warmup_steps: int = 0, + warmup_ratio: float | None = None, + warmup_start_factor: float = 0.0, **kwargs: Any, ) -> None: """ - Defines a cosine annealing learning rate schedule. - The learning rate starts at `start_lr` and gradually decreases to `stop_lr` - following a cosine curve over the training steps. + Construct a cosine annealing learning rate schedule with optional warmup. Parameters ---------- - start_lr - The initial learning rate at the beginning of training. - stop_lr + start_lr : float + The learning rate at the start of the training (after warmup). + stop_lr : float, optional The final learning rate at the end of training. - stop_steps - The total number of training steps over which the learning rate - will be annealed from start_lr to stop_lr. + Mutually exclusive with stop_lr_ratio. + stop_lr_ratio : float, optional + The ratio of stop_lr to start_lr. + Mutually exclusive with stop_lr. + num_steps : int + The total training steps (including warmup). + warmup_steps : int, optional + The number of steps for learning rate warmup. + Mutually exclusive with warmup_ratio. Default is 0. + warmup_ratio : float, optional + The ratio of warmup steps to total training steps. + Mutually exclusive with warmup_steps. + warmup_start_factor : float, optional + The factor of start_lr for the initial warmup learning rate. + Default is 0.0. + + Raises + ------ + ValueError + If both stop_lr and stop_lr_ratio are provided, or neither is provided. + If both warmup_steps and warmup_ratio are provided. + """ + super().__init__( + start_lr=start_lr, + stop_lr=stop_lr, + stop_lr_ratio=stop_lr_ratio, + num_steps=num_steps, + warmup_steps=warmup_steps, + warmup_ratio=warmup_ratio, + warmup_start_factor=warmup_start_factor, + **kwargs, + ) + self.lr_min_factor = self.stop_lr / self.start_lr + + def _decay_value(self, step: int | Array) -> Array: """ - super().__init__(start_lr, stop_lr, stop_steps, **kwargs) - self.lr_min_factor = stop_lr / start_lr + Get the cosine-annealed learning rate at the given step. - def value(self, step: int | Array) -> Array: + Parameters + ---------- + step : int or Array + The step index relative to the end of warmup. + + Returns + ------- + Array + The annealed learning rate (absolute value). 
+ """ if not array_api_compat.is_array_api_obj(step): step = np.asarray(step) xp = array_api_compat.array_namespace(step) min_lr = self.start_lr * self.lr_min_factor + # Handle decay_num_steps=0 (no training steps) - return start_lr + if self.decay_num_steps == 0: + return xp.full_like(step, self.start_lr, dtype=xp.float64) step_lr = self.start_lr * ( self.lr_min_factor + 0.5 @@ -153,11 +473,12 @@ def value(self, step: int | Array) -> Array: 1 + xp.cos( xp.asarray( - xp.pi * (xp.astype(step, xp.float64) / self.stop_steps), + xp.pi * (xp.astype(step, xp.float64) / self.decay_num_steps), device=array_api_compat.device(step), ) ) ) ) - step_lr = xp.where(step >= self.stop_steps, min_lr, step_lr) + # Clip to min_lr for steps beyond decay_num_steps + step_lr = xp.where(step >= self.decay_num_steps, min_lr, step_lr) return step_lr diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index dd0fbdc94b..3ebb9f24b4 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -239,7 +239,7 @@ def get_sample(): return get_sample def get_lr(lr_params: dict[str, Any]) -> BaseLR: - lr_params["stop_steps"] = self.num_steps - self.warmup_steps + lr_params["num_steps"] = self.num_steps lr_schedule = BaseLR(**lr_params) return lr_schedule @@ -387,11 +387,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: ) # Learning rate - self.warmup_steps = training_params.get("warmup_steps", 0) self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0) - assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, ( - "Warm up steps must be less than total training steps!" - ) if self.multi_task and config.get("learning_rate_dict", None) is not None: self.lr_exp = {} for model_key in self.model_keys: @@ -580,18 +576,13 @@ def single_model_finetune( # TODO add lr warmups for multitask # author: iProzd - def warm_up_linear(step, warmup_steps): - if step < warmup_steps: - return step / warmup_steps - else: - return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr - # TODO add optimizers for multitask # author: iProzd if self.opt_type == "Adam": self.scheduler = paddle.optimizer.lr.LambdaDecay( learning_rate=self.lr_exp.start_lr, - lr_lambda=lambda step: warm_up_linear(step, self.warmup_steps), + lr_lambda=lambda step: self.lr_exp.value(step + self.start_step) + / self.lr_exp.start_lr, ) self.optimizer = paddle.optimizer.Adam( learning_rate=self.scheduler, parameters=self.wrapper.parameters() @@ -755,10 +746,7 @@ def step(_step_id, task_key="Default") -> None: fout1.flush() if self.opt_type == "Adam": cur_lr = self.scheduler.get_lr() - if _step_id < self.warmup_steps: - pref_lr = _lr.start_lr - else: - pref_lr = cur_lr + pref_lr = cur_lr # disable synchronization in forward-backward manually # as derivatives exist in model forward diff --git a/deepmd/pd/utils/utils.py b/deepmd/pd/utils/utils.py index 7224547805..4c4222dd5f 100644 --- a/deepmd/pd/utils/utils.py +++ b/deepmd/pd/utils/utils.py @@ -27,6 +27,7 @@ from .env import ( DEVICE, + GLOBAL_NP_FLOAT_PRECISION, ) from .env import PRECISION_DICT as PD_PRECISION_DICT @@ -239,7 +240,8 @@ def to_numpy_array( ): if xx is None: return None - assert xx is not None + if isinstance(xx, (float, int)): + return np.array(xx, dtype=GLOBAL_NP_FLOAT_PRECISION) # Create a reverse mapping of PD_PRECISION_DICT reverse_precision_dict = {v: k for k, v in PD_PRECISION_DICT.items()} # Use the reverse mapping to find keys with the desired value diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 
0dfbe94b6b..c4a7b39047 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -273,7 +273,7 @@ def get_sample() -> Any: return get_sample def get_lr(lr_params: dict[str, Any]) -> BaseLR: - lr_params["stop_steps"] = self.num_steps - self.warmup_steps + lr_params["num_steps"] = self.num_steps lr_schedule = BaseLR(**lr_params) return lr_schedule @@ -431,27 +431,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: ) # Learning rate - warmup_steps = training_params.get("warmup_steps", None) - warmup_ratio = training_params.get("warmup_ratio", None) - if warmup_steps is not None: - self.warmup_steps = warmup_steps - elif warmup_ratio is not None: - if not 0 <= warmup_ratio < 1: - raise ValueError(f"warmup_ratio must be in [0, 1), got {warmup_ratio}") - self.warmup_steps = int(warmup_ratio * self.num_steps) - if self.warmup_steps == 0 and warmup_ratio > 0: - log.warning( - f"warmup_ratio {warmup_ratio} results in 0 warmup steps " - f"due to truncation. Consider using a larger ratio or " - f"specify warmup_steps directly." - ) - else: - self.warmup_steps = 0 - self.warmup_start_factor = training_params.get("warmup_start_factor", 0.0) self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0) - assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, ( - "Warm up steps must be less than total training steps!" - ) if self.multi_task and config.get("learning_rate_dict", None) is not None: self.lr_exp = {} for model_key in self.model_keys: @@ -697,14 +677,6 @@ def single_model_finetune( # TODO add lr warmups for multitask # author: iProzd - def warm_up_linear(step: int, warmup_steps: int) -> float: - if step < warmup_steps: - return self.warmup_start_factor + (1.0 - self.warmup_start_factor) * ( - step / warmup_steps - ) - else: - return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr - # TODO add optimizers for multitask # author: iProzd if self.opt_type in ["Adam", "AdamW"]: @@ -725,7 +697,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: self.optimizer.load_state_dict(optimizer_state_dict) self.scheduler = torch.optim.lr_scheduler.LambdaLR( self.optimizer, - lambda step: warm_up_linear(step + self.start_step, self.warmup_steps), + lambda step: self.lr_exp.value(step + self.start_step) + / self.lr_exp.start_lr, ) elif self.opt_type == "LKF": self.optimizer = LKFOptimizer( @@ -748,7 +721,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: self.optimizer.load_state_dict(optimizer_state_dict) self.scheduler = torch.optim.lr_scheduler.LambdaLR( self.optimizer, - lambda step: warm_up_linear(step + self.start_step, self.warmup_steps), + lambda step: self.lr_exp.value(step + self.start_step) + / self.lr_exp.start_lr, ) else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") @@ -822,10 +796,7 @@ def step(_step_id: int, task_key: str = "Default") -> None: fout1.flush() if self.opt_type in ["Adam", "AdamW", "AdaMuon"]: cur_lr = self.scheduler.get_last_lr()[0] - if _step_id < self.warmup_steps: - pref_lr = _lr.start_lr - else: - pref_lr = cur_lr + pref_lr = cur_lr model_pred, loss, more_loss = self.wrapper( **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key ) diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index ab066bdf93..67039c0efa 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -16,6 +16,7 @@ from .env import ( DEVICE, + GLOBAL_NP_FLOAT_PRECISION, ) from .env import PRECISION_DICT as PT_PRECISION_DICT @@ -227,11 +228,14 @@ def 
to_numpy_array(xx: None) -> None: ... def to_numpy_array( - xx: torch.Tensor | None, + xx: torch.Tensor | np.ndarray | float | None, ) -> np.ndarray | None: if xx is None: return None - assert xx is not None + if isinstance(xx, (float, int)): + return np.array(xx, dtype=GLOBAL_NP_FLOAT_PRECISION) + if isinstance(xx, np.ndarray): + return xx.astype(GLOBAL_NP_FLOAT_PRECISION) # Create a reverse mapping of PT_PRECISION_DICT reverse_precision_dict = {v: k for k, v in PT_PRECISION_DICT.items()} # Use the reverse mapping to find keys with the desired value @@ -239,6 +243,7 @@ def to_numpy_array( prec = NP_PRECISION_DICT.get(prec, None) if prec is None: raise ValueError(f"unknown precision {xx.dtype}") + assert isinstance(xx, torch.Tensor) if xx.dtype == torch.bfloat16: # https://github.com/pytorch/pytorch/issues/109873 xx = xx.float() diff --git a/deepmd/tf/fit/dipole.py b/deepmd/tf/fit/dipole.py index 961198b8e7..ebeec270e0 100644 --- a/deepmd/tf/fit/dipole.py +++ b/deepmd/tf/fit/dipole.py @@ -388,7 +388,7 @@ def get_loss(self, loss: dict, lr) -> Loss: ---------- loss : dict the loss dict - lr : LearningRateExp + lr : LearningRateSchedule the learning rate Returns diff --git a/deepmd/tf/fit/dos.py b/deepmd/tf/fit/dos.py index 250d803d8f..bec8814d18 100644 --- a/deepmd/tf/fit/dos.py +++ b/deepmd/tf/fit/dos.py @@ -655,7 +655,7 @@ def get_loss(self, loss: dict, lr) -> Loss: ---------- loss : dict the loss dict - lr : LearningRateExp + lr : LearningRateSchedule the learning rate Returns diff --git a/deepmd/tf/fit/ener.py b/deepmd/tf/fit/ener.py index 2b8b1b906e..6a027b2ec2 100644 --- a/deepmd/tf/fit/ener.py +++ b/deepmd/tf/fit/ener.py @@ -856,7 +856,7 @@ def get_loss(self, loss: dict, lr) -> Loss: ---------- loss : dict The loss function parameters. - lr : LearningRateExp + lr : LearningRateSchedule The learning rate. 
Returns diff --git a/deepmd/tf/fit/fitting.py b/deepmd/tf/fit/fitting.py index b33559f12f..f7e5d959ef 100644 --- a/deepmd/tf/fit/fitting.py +++ b/deepmd/tf/fit/fitting.py @@ -73,7 +73,7 @@ def get_loss(self, loss: dict, lr) -> Loss: ---------- loss : dict the loss dict - lr : LearningRateExp + lr : LearningRateSchedule the learning rate Returns diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 1e48a5fa59..137695d9b8 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -863,7 +863,7 @@ def get_loss(self, loss: dict, lr) -> Loss: ---------- loss : dict the loss dict - lr : LearningRateExp + lr : LearningRateSchedule the learning rate Returns diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 4af59fd290..2ee726cb40 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -4,6 +4,9 @@ import os import shutil import time +from typing import ( + Any, +) import google.protobuf.message import numpy as np @@ -52,7 +55,7 @@ load_graph_def, ) from deepmd.tf.utils.learning_rate import ( - LearningRateExp, + LearningRateSchedule, ) from deepmd.tf.utils.sess import ( run_sess, @@ -100,7 +103,9 @@ def _init_param(self, jdata) -> None: self.model = Model(**model_param) self.fitting = self.model.get_fitting() - def get_lr_and_coef(lr_param): + def get_lr_and_coef( + lr_param: dict[str, Any], + ) -> tuple[LearningRateSchedule, float]: scale_by_worker = lr_param.get("scale_by_worker", "linear") if scale_by_worker == "linear": scale_lr_coef = float(self.run_opt.world_size) @@ -108,13 +113,8 @@ def get_lr_and_coef(lr_param): scale_lr_coef = np.sqrt(self.run_opt.world_size).real else: scale_lr_coef = 1.0 - lr_type = lr_param.get("type", "exp") - if lr_type == "exp": - lr = LearningRateExp( - lr_param["start_lr"], lr_param["stop_lr"], lr_param["decay_steps"] - ) - else: - raise RuntimeError("unknown learning_rate type " + lr_type) + lr_params = {k: v for k, v in lr_param.items() if k != "scale_by_worker"} + lr = LearningRateSchedule(lr_params) return lr, scale_lr_coef # learning rate @@ -242,8 +242,13 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix="") -> Non def _build_lr(self) -> None: self._extra_train_ops = [] self.global_step = tf.train.get_or_create_global_step() - self.learning_rate = self.lr.build(self.global_step, self.stop_batch) - log.info("built lr") + if self.stop_batch == 0: + # Use constant start_lr when stop_batch is zero (no training) + self.learning_rate = tf.cast(self.lr.start_lr(), GLOBAL_TF_FLOAT_PRECISION) + log.info("built lr (constant start_lr for stop_batch=0)") + else: + self.learning_rate = self.lr.build(self.global_step, self.stop_batch) + log.info("built lr") def _build_loss(self): if self.stop_batch == 0: @@ -426,14 +431,21 @@ def train(self, train_data=None, valid_data=None) -> None: elapsed_batch = stop_batch - start_batch is_first_step = True self.cur_batch = cur_batch - log.info( - "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e", - run_sess(self.sess, self.learning_rate), - self.lr.value(cur_batch), - self.lr.decay_steps_, - self.lr.decay_rate_, - self.lr.value(stop_batch), - ) + if stop_batch == 0: + lr0 = self.lr.start_lr() + log.info( + "start training at lr %.2e (== %.2e), final lr will be %.2e", + run_sess(self.sess, self.learning_rate), + lr0, + lr0, + ) + else: + log.info( + "start training at lr %.2e (== %.2e), final lr will be %.2e", + run_sess(self.sess, self.learning_rate), + self.lr.value(cur_batch), + 
self.lr.value(stop_batch), + ) prf_options = None prf_run_metadata = None @@ -797,7 +809,7 @@ def _get_place_holders(self, data_dict) -> None: prec = GLOBAL_ENER_FLOAT_PRECISION self.place_holders[kk] = tf.placeholder(prec, [None], name="t_" + kk) self.place_holders["find_" + kk] = tf.placeholder( - tf.float32, name="t_find_" + kk + GLOBAL_TF_FLOAT_PRECISION, name="t_find_" + kk ) def _init_from_frz_model(self) -> None: diff --git a/deepmd/tf/utils/__init__.py b/deepmd/tf/utils/__init__.py index 7d1e7e67d0..b88c13d445 100644 --- a/deepmd/tf/utils/__init__.py +++ b/deepmd/tf/utils/__init__.py @@ -7,7 +7,7 @@ DeepmdDataSystem, ) from .learning_rate import ( - LearningRateExp, + LearningRateSchedule, ) from .pair_tab import ( PairTab, @@ -20,7 +20,7 @@ __all__ = [ "DeepmdData", "DeepmdDataSystem", - "LearningRateExp", + "LearningRateSchedule", "PairTab", "Plugin", "PluginVariant", diff --git a/deepmd/tf/utils/learning_rate.py b/deepmd/tf/utils/learning_rate.py index 64427e185d..b6111398d0 100644 --- a/deepmd/tf/utils/learning_rate.py +++ b/deepmd/tf/utils/learning_rate.py @@ -1,102 +1,136 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) + +from typing import ( + Any, +) import numpy as np +from deepmd.dpmodel.utils.learning_rate import ( + BaseLR, +) from deepmd.tf.env import ( + GLOBAL_TF_FLOAT_PRECISION, tf, ) -class LearningRateExp: - r"""The exponentially decaying learning rate. - - The learning rate at step :math:`t` is given by - - .. math:: - - \alpha(t) = \alpha_0 \lambda ^ { t / \tau } +class LearningRateSchedule: + """ + TensorFlow wrapper for BaseLR. - where :math:`\alpha` is the learning rate, :math:`\alpha_0` is the starting learning rate, - :math:`\lambda` is the decay rate, and :math:`\tau` is the decay steps. + The learning rate is computed via :func:`tf.numpy_function`, which prevents + TensorFlow from optimizing this operation in the graph. This overhead is + typically negligible compared to forward/backward passes. Parameters ---------- - start_lr - Starting learning rate :math:`\alpha_0` - stop_lr - Stop learning rate :math:`\alpha_1` - decay_steps - Learning rate decay every this number of steps :math:`\tau` - decay_rate - The decay rate :math:`\lambda`. - If `stop_step` is provided in `build`, then it will be determined automatically and overwritten. + params : dict[str, Any] + Learning rate configuration dictionary. """ - def __init__( - self, - start_lr: float, - stop_lr: float = 5e-8, - decay_steps: int = 5000, - decay_rate: float = 0.95, - ) -> None: - """Constructor.""" - self.cd = {} - self.cd["start_lr"] = start_lr - self.cd["stop_lr"] = stop_lr - self.cd["decay_steps"] = decay_steps - self.cd["decay_rate"] = decay_rate - self.start_lr_ = self.cd["start_lr"] - - def build(self, global_step: tf.Tensor, stop_step: int | None = None) -> tf.Tensor: - """Build the learning rate. + def __init__(self, params: dict[str, Any]) -> None: + self._params = dict(params) + self._base_lr: BaseLR | None = None + + def start_lr(self) -> float: + """ + Get the starting learning rate. + + Returns + ------- + float + The starting learning rate. + """ + return float(self._params["start_lr"]) + + @property + def base_lr(self) -> BaseLR: + """ + Get the built BaseLR instance. + + Returns + ------- + BaseLR + The built learning rate schedule. + + Raises + ------ + RuntimeError + If the schedule has not been built. 
+ """ + if self._base_lr is None: + raise RuntimeError("Learning rate schedule is not built yet.") + return self._base_lr + + def build(self, global_step: tf.Tensor, num_steps: int) -> tf.Tensor: + """ + Build a TensorFlow learning rate tensor. Parameters ---------- - global_step - The tf Tensor providing the global training step - stop_step - The stop step. If provided, the decay_rate will be determined automatically and overwritten. + global_step : tf.Tensor + The global training step tensor. + num_steps : int + The total training steps. Returns ------- - learning_rate - The learning rate + tf.Tensor + The learning rate tensor. """ - if stop_step is None: - self.decay_steps_ = ( - self.cd["decay_steps"] if self.cd["decay_steps"] is not None else 5000 - ) - self.decay_rate_ = ( - self.cd["decay_rate"] if self.cd["decay_rate"] is not None else 0.95 - ) - else: - self.stop_lr_ = ( - self.cd["stop_lr"] if self.cd["stop_lr"] is not None else 5e-8 - ) - default_ds = 100 if stop_step // 10 > 100 else stop_step // 100 + 1 - self.decay_steps_ = ( - self.cd["decay_steps"] - if self.cd["decay_steps"] is not None - else default_ds - ) - if self.decay_steps_ >= stop_step: - self.decay_steps_ = default_ds - self.decay_rate_ = np.exp( - np.log(self.stop_lr_ / self.start_lr_) / (stop_step / self.decay_steps_) + # === Step 1. Instantiate backend-agnostic schedule === + params = dict(self._params) + params["num_steps"] = num_steps + # Default to 'exp' type if not specified + if "type" not in params: + params["type"] = "exp" + self._base_lr = BaseLR(**params) + + # === Step 2. Bind a numpy_function for runtime evaluation === + base_lr = self._base_lr + + def _lr_value(step: np.ndarray) -> np.ndarray: + # Use GLOBAL_TF_FLOAT_PRECISION (float64) for learning rate, + # consistent with energy precision in TF backend + return np.asarray( + base_lr.value(step), + dtype=GLOBAL_TF_FLOAT_PRECISION.as_numpy_dtype, ) - return tf.train.exponential_decay( - self.start_lr_, - global_step, - self.decay_steps_, - self.decay_rate_, - staircase=True, + lr = tf.numpy_function( + _lr_value, [global_step], Tout=GLOBAL_TF_FLOAT_PRECISION, name="lr_schedule" ) - - def start_lr(self) -> float: - """Get the start lr.""" - return self.start_lr_ + lr.set_shape(global_step.get_shape()) + return lr def value(self, step: int) -> float: - """Get the lr at a certain step.""" - return self.start_lr_ * np.power(self.decay_rate_, (step // self.decay_steps_)) + """ + Get the learning rate at the given step. + + Parameters + ---------- + step : int + The step index. + + Returns + ------- + float + The learning rate value. + + Raises + ------ + RuntimeError + If the schedule has not been built. + """ + if self._base_lr is None: + raise RuntimeError("Learning rate schedule is not built yet.") + return float(np.asarray(self._base_lr.value(step))) + + +__all__ = [ + "LearningRateSchedule", +] diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 935762cdc7..4dc4573b6d 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2480,55 +2480,239 @@ def linear_ener_model_args() -> Argument: lr_args_plugin = ArgsPlugin() +def _check_lr_stop_args(data: dict[str, Any]) -> bool: + """ + Check that stop_lr and stop_lr_ratio are mutually exclusive and at least one is provided. + + Parameters + ---------- + data : dict[str, Any] + The learning rate configuration dictionary. + + Returns + ------- + bool + True if validation passes. 
+ + Raises + ------ + ValueError + If both stop_lr and stop_lr_ratio are provided, or neither is provided. + """ + has_stop_lr = "stop_lr" in data and data["stop_lr"] is not None + has_stop_lr_ratio = "stop_lr_ratio" in data and data["stop_lr_ratio"] is not None + + if has_stop_lr and has_stop_lr_ratio: + raise ValueError( + "stop_lr and stop_lr_ratio are mutually exclusive. " + f"Got stop_lr={data['stop_lr']}, stop_lr_ratio={data['stop_lr_ratio']}" + ) + if not has_stop_lr and not has_stop_lr_ratio: + raise ValueError( + "Either stop_lr or stop_lr_ratio must be provided. " + "Got stop_lr=None, stop_lr_ratio=None" + ) + return True + + +def _check_warmup_args(data: dict[str, Any]) -> bool: + """ + Check that warmup_steps and warmup_ratio are mutually exclusive. + + Parameters + ---------- + data : dict[str, Any] + The learning rate configuration dictionary. + + Returns + ------- + bool + True if validation passes. + + Raises + ------ + ValueError + If both warmup_steps (non-zero) and warmup_ratio are provided. + """ + # warmup_steps default is 0, so check for non-zero value + has_warmup_steps = "warmup_steps" in data and data["warmup_steps"] != 0 + has_warmup_ratio = "warmup_ratio" in data and data["warmup_ratio"] is not None + + if has_warmup_steps and has_warmup_ratio: + raise ValueError( + "warmup_steps and warmup_ratio are mutually exclusive. " + f"Got warmup_steps={data['warmup_steps']}, warmup_ratio={data['warmup_ratio']}" + ) + return True + + +def _check_decay_steps_args(data: dict[str, Any]) -> bool: + """ + Check that decay_steps is positive and decay_rate is valid for exponential learning rate. + + Parameters + ---------- + data : dict[str, Any] + The learning rate configuration dictionary. + + Returns + ------- + bool + True if validation passes. + + Raises + ------ + ValueError + If decay_steps is not positive. + If decay_rate is not positive. + """ + lr_type = data.get("type", "exp") + if lr_type != "exp": + return True + + decay_steps = data.get("decay_steps") + if decay_steps is not None and decay_steps <= 0: + raise ValueError(f"decay_steps ({decay_steps}) must be positive.") + + decay_rate = data.get("decay_rate") + if decay_rate is not None and (decay_rate <= 0 or decay_rate > 1): + raise ValueError( + f"decay_rate ({decay_rate}) must be in (0, 1] for exponential decay." + ) + return True + + +def _learning_rate_common_args( + doc_stop_lr: str, + extra_args: list[Argument] | None = None, +) -> list[Argument]: + doc_start_lr = "The learning rate at the start of the training (after warmup)." + doc_stop_lr_ratio = ( + "The ratio of stop_lr to start_lr. stop_lr = start_lr * stop_lr_ratio. " + "Mutually exclusive with stop_lr." + ) + doc_warmup_steps = ( + "The number of steps for learning rate warmup. " + "During warmup, the learning rate increases linearly from " + "warmup_start_factor * start_lr to start_lr. " + "Mutually exclusive with warmup_ratio. Default is 0 (no warmup)." + ) + doc_warmup_ratio = ( + "The ratio of warmup steps to total training steps. " + "The actual number of warmup steps is int(warmup_ratio * num_steps)." + "Mutually exclusive with warmup_steps." + ) + doc_warmup_start_factor = ( + "The factor of start_lr for the initial warmup learning rate. " + "The warmup learning rate starts from warmup_start_factor * start_lr. " + "Default is 0.0, meaning the learning rate starts from zero." 
+ ) + + args = [ + Argument("start_lr", float, optional=False, doc=doc_start_lr), + Argument( + "stop_lr", + float, + optional=True, + default=None, + doc=doc_stop_lr, + ), + Argument( + "stop_lr_ratio", + float, + optional=True, + default=None, + doc=doc_stop_lr_ratio, + ), + ] + if extra_args: + args.extend(extra_args) + args.extend( + [ + Argument( + "warmup_steps", + int, + optional=True, + default=0, + doc=doc_warmup_steps, + ), + Argument( + "warmup_ratio", + float, + optional=True, + default=None, + doc=doc_warmup_ratio, + ), + Argument( + "warmup_start_factor", + float, + optional=True, + default=0.0, + doc=doc_warmup_start_factor, + ), + ] + ) + return args + + @lr_args_plugin.register("exp") def learning_rate_exp() -> list[Argument]: - doc_start_lr = "The learning rate at the start of the training." + """ + Defines an exponential-decayed learning rate schedule with optional warmup. + + The learning rate starts at `start_lr` (after warmup) and decays exponentially + to `stop_lr` over the training steps. + """ doc_stop_lr = ( "The desired learning rate at the end of the training. " - f"When decay_rate {doc_only_pt_supported}is explicitly set, " + "When decay_rate is explicitly set, " "this value will serve as the minimum learning rate during training. " - "In other words, if the learning rate decays below stop_lr, stop_lr will be applied instead." + "In other words, if the learning rate decays below stop_lr, stop_lr will be applied instead. " + "Mutually exclusive with stop_lr_ratio." ) doc_decay_steps = ( - "The learning rate is decaying every this number of training steps." + "The learning rate is decaying every this number of training steps. " + "If decay_steps exceeds the decay phase steps (num_steps - warmup_steps) " + "and decay_rate is not provided, it will be automatically adjusted to a " + "sensible default value." ) doc_decay_rate = ( "The decay rate for the learning rate. " "If this is provided, it will be used directly as the decay rate for learning rate " "instead of calculating it through interpolation between start_lr and stop_lr." ) + doc_smooth = ( + "If True, use smooth exponential decay (lr decays continuously). " + "If False (default), use stepped decay (lr decays every decay_steps)." + ) - args = [ - Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr), - Argument("stop_lr", float, optional=True, default=1e-8, doc=doc_stop_lr), + extra_args = [ Argument("decay_steps", int, optional=True, default=5000, doc=doc_decay_steps), Argument( "decay_rate", float, optional=True, default=None, - doc=doc_only_pt_supported + doc_decay_rate, + doc=doc_decay_rate, ), + Argument("smooth", bool, optional=True, default=False, doc=doc_smooth), ] - return args + return _learning_rate_common_args(doc_stop_lr, extra_args=extra_args) -@lr_args_plugin.register("cosine", doc=doc_only_pt_supported) +@lr_args_plugin.register("cosine") def learning_rate_cosine() -> list[Argument]: """ - Defines a cosine annealing learning rate schedule. + Defines a cosine annealing learning rate schedule with optional warmup. - The learning rate starts at `start_lr` and gradually decreases to `stop_lr` - following a cosine curve over the training steps. + The learning rate starts at `start_lr` (after warmup) and gradually + decreases to `stop_lr` following a cosine curve over the training steps. """ - doc_start_lr = "The learning rate at the start of the training." - doc_stop_lr = "The desired learning rate at the end of the training. 
" - - args = [ - Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr), - Argument("stop_lr", float, optional=True, default=1e-5, doc=doc_stop_lr), - ] - return args + doc_stop_lr = ( + "The desired learning rate at the end of training. " + "Mutually exclusive with stop_lr_ratio." + ) + return _learning_rate_common_args(doc_stop_lr) def learning_rate_variant_type_args() -> Variant: @@ -2546,6 +2730,17 @@ def learning_rate_variant_type_args() -> Variant: def learning_rate_args(fold_subdoc: bool = False) -> Argument: doc_scale_by_worker = "When parallel training or batch size scaled, how to alter learning rate. Valid values are `linear`(default), `sqrt` or `none`." doc_lr = "The definition of learning rate" + + def _check_lr_args(data: dict[str, Any]) -> bool: + """Check learning rate argument constraints.""" + # Check stop_lr and stop_lr_ratio + _check_lr_stop_args(data) + # Check warmup_steps and warmup_ratio + _check_warmup_args(data) + # Check decay_steps and decay_rate + _check_decay_steps_args(data) + return True + return Argument( "learning_rate", dict, @@ -2562,6 +2757,7 @@ def learning_rate_args(fold_subdoc: bool = False) -> Argument: optional=True, doc=doc_lr, fold_subdoc=fold_subdoc, + extra_check=_check_lr_args, ) @@ -3240,22 +3436,6 @@ def training_args( doc_tensorboard = "Enable tensorboard" doc_tensorboard_log_dir = "The log directory of tensorboard outputs" doc_tensorboard_freq = "The frequency of writing tensorboard events." - doc_warmup_steps = ( - "The number of steps for learning rate warmup. During warmup, " - "the learning rate begins at zero and progressively increases linearly to `start_lr`, " - "rather than starting directly from `start_lr`" - ) - doc_warmup_ratio = ( - "The ratio of warmup steps to total training steps. " - "The actual number of warmup steps is calculated as `warmup_ratio * numb_steps`. " - "Valid values are in the range [0.0, 1.0). " - "If `warmup_steps` is set, this option will be ignored." - ) - doc_warmup_start_factor = ( - "The factor of start learning rate to the target learning rate during warmup. " - "The warmup learning rate will linearly increase from `warmup_start_factor * start_lr` to `start_lr`. " - "Default is 0.0, meaning the learning rate starts from zero." - ) doc_gradient_max_norm = ( "Clips the gradient norm to a maximum value. " "If the gradient norm exceeds this value, it will be clipped to this limit. " @@ -3363,25 +3543,6 @@ def training_args( Argument( "tensorboard_freq", int, optional=True, default=1, doc=doc_tensorboard_freq ), - Argument( - "warmup_steps", - int, - optional=True, - doc=doc_only_pt_supported + doc_warmup_steps, - ), - Argument( - "warmup_ratio", - float, - optional=True, - doc=doc_only_pt_supported + doc_warmup_ratio, - ), - Argument( - "warmup_start_factor", - float, - optional=True, - default=0.0, - doc=doc_only_pt_supported + doc_warmup_start_factor, - ), Argument( "gradient_max_norm", float, diff --git a/doc/train/training-advanced.md b/doc/train/training-advanced.md index af4b4b31d9..533fb3604f 100644 --- a/doc/train/training-advanced.md +++ b/doc/train/training-advanced.md @@ -6,44 +6,258 @@ In this section, we will take `$deepmd_source_dir/examples/water/se_e2_a/input.j ### Theory -The learning rate $\gamma$ decays exponentially: +The learning rate schedule consists of two phases: an optional warmup phase followed by a decay phase. 
+ +#### Warmup phase (optional) + +During the warmup phase (steps $0 \leq \tau < \tau^{\text{warmup}}$), the learning rate increases linearly from an initial warmup learning rate to the target starting learning rate: + +```math + \gamma(\tau) = \gamma^{\text{warmup}} + \frac{\gamma^0 - \gamma^{\text{warmup}}}{\tau^{\text{warmup}}} \tau, +``` + +where $\gamma^{\text{warmup}} = f^{\text{warmup}} \cdot \gamma^0$ is the initial warmup learning rate, $f^{\text{warmup}} \in [0, 1]$ is the warmup start factor (default 0.0), and $\tau^{\text{warmup}} \in \mathbb{N}$ is the number of warmup steps. + +#### Decay phase + +After the warmup phase (steps $\tau \geq \tau^{\text{warmup}}$), the learning rate decays according to the selected schedule type. + +**Exponential decay (`type: "exp"`):** + +The learning rate decays exponentially: ```math - \gamma(\tau) = \gamma^0 r ^ {\lfloor \tau/s \rfloor}, + \gamma(\tau) = \gamma^0 r ^ {\lfloor (\tau - \tau^{\text{warmup}})/s \rfloor}, ``` -where $\tau \in \mathbb{N}$ is the index of the training step, $\gamma^0 \in \mathbb{R}$ is the learning rate at the first step, and the decay rate $r$ is given by +where $\tau \in \mathbb{N}$ is the index of the training step, $\gamma^0 \in \mathbb{R}$ is the learning rate at the start of the decay phase (i.e., after warmup), and the decay rate $r$ is given by ```math - r = {\left(\frac{\gamma^{\text{stop}}}{\gamma^0}\right )} ^{\frac{s}{\tau^{\text{stop}}}}, + r = {\left(\frac{\gamma^{\text{stop}}}{\gamma^0}\right )} ^{\frac{s}{\tau^{\text{decay}}}}, ``` -where $\tau^{\text{stop}} \in \mathbb{N}$, $\gamma^{\text{stop}} \in \mathbb{R}$, and $s \in \mathbb{N}$ are the stopping step, the stopping learning rate, and the decay steps, respectively, all of which are hyperparameters provided in advance. +where $\tau^{\text{decay}} = \tau^{\text{stop}} - \tau^{\text{warmup}}$ is the number of decay steps, $\tau^{\text{stop}} \in \mathbb{N}$ is the total training steps, $\gamma^{\text{stop}} \in \mathbb{R}$ is the stopping learning rate, and $s \in \mathbb{N}$ is the decay steps. + +**Cosine annealing (`type: "cosine"`):** + +The learning rate follows a cosine annealing schedule: + +```math + \gamma(\tau) = \gamma^{\text{stop}} + \frac{\gamma^0 - \gamma^{\text{stop}}}{2} \left(1 + \cos\left(\frac{\pi (\tau - \tau^{\text{warmup}})}{\tau^{\text{decay}}}\right)\right), +``` + +where the learning rate smoothly decreases from $\gamma^0$ to $\gamma^{\text{stop}}$ following a cosine curve over the decay phase. + +For both schedule types, the stopping learning rate can be specified directly as $\gamma^{\text{stop}}$ or as a ratio: $\gamma^{\text{stop}} = \rho^{\text{stop}} \cdot \gamma^0$, where $\rho^{\text{stop}} \in (0, 1]$ is the stopping learning rate ratio. [^1] [^1]: This section is built upon Jinzhe Zeng, Duo Zhang, Denghui Lu, Pinghui Mo, Zeyu Li, Yixiao Chen, Marián Rynik, Li'ang Huang, Ziyao Li, Shaochen Shi, Yingze Wang, Haotian Ye, Ping Tuo, Jiabin Yang, Ye Ding, Yifan Li, Davide Tisi, Qiyu Zeng, Han Bao, Yu Xia, Jiameng Huang, Koki Muraoka, Yibo Wang, Junhan Chang, Fengbo Yuan, Sigbjørn Løland Bore, Chun Cai, Yinnian Lin, Bo Wang, Jiayan Xu, Jia-Xin Zhu, Chenxing Luo, Yuzhi Zhang, Rhys E. A. Goodall, Wenshuo Liang, Anurag Kumar Singh, Sikai Yao, Jingchao Zhang, Renata Wentzcovitch, Jiequn Han, Jie Liu, Weile Jia, Darrin M. York, Weinan E, Roberto Car, Linfeng Zhang, Han Wang, [J. Chem. Phys. 
159, 054801 (2023)](https://doi.org/10.1063/5.0155600) licensed under a [Creative Commons Attribution (CC BY) license](http://creativecommons.org/licenses/by/4.0/). ### Instructions -The {ref}`learning_rate ` section in `input.json` is given as follows +DeePMD-kit supports two types of learning rate schedules: exponential decay (`type: "exp"`) and cosine annealing (`type: "cosine"`). Both types support optional warmup and can use either absolute stopping learning rate or a ratio-based specification. + +#### Exponential decay schedule + +The {ref}`learning_rate ` section for exponential decay in `input.json` is given as follows ```json "learning_rate" :{ - "type": "exp", - "start_lr": 0.001, - "stop_lr": 3.51e-8, - "decay_steps": 5000, - "_comment": "that's all" + "type": "exp", + "start_lr": 0.001, + "stop_lr": 1e-6, + "decay_steps": 5000, + "_comment": "that's all" + } +``` + +#### Basic parameters + +The following parameters are available for learning rate configuration. + +**Common parameters for both `exp` and `cosine` types:** + +- {ref}`start_lr ` gives the learning rate at the start of the decay phase (i.e., after warmup if enabled). It should be set appropriately based on the model architecture and dataset. +- {ref}`stop_lr ` gives the target learning rate at the end of the training. It should be small enough to ensure that the network parameters satisfactorily converge. This parameter is mutually exclusive with {ref}`stop_lr_ratio `. +- {ref}`stop_lr_ratio ` (optional) specifies the stopping learning rate as a ratio of {ref}`start_lr `. For example, `stop_lr_ratio: 1e-3` means `stop_lr = start_lr * 1e-3`. This parameter is mutually exclusive with {ref}`stop_lr `. Either {ref}`stop_lr ` or {ref}`stop_lr_ratio ` must be provided. + +**Additional parameters for `exp` type only:** + +- {ref}`decay_steps ` specifies the interval (in training steps) at which the learning rate is decayed. The learning rate is updated every {ref}`decay_steps ` steps during the decay phase. If `decay_steps` exceeds the decay phase steps (num_steps - warmup_steps) and `decay_rate` is not explicitly provided, it will be automatically adjusted to a sensible default value. +- {ref}`smooth ` (optional, default: `false`) controls the decay behavior. When set to `false`, the learning rate decays in a stepped manner (updated every `decay_steps` steps). When set to `true`, the learning rate decays smoothly at every step. + +**Learning rate formula for `exp` type:** + +During the decay phase, the learning rate decays exponentially from {ref}`start_lr ` to {ref}`stop_lr `. + +- **Stepped mode (`smooth: false`, default):** + +```text +lr(t) = start_lr * decay_rate ^ floor((t - warmup_steps) / decay_steps) +``` + +- **Smooth mode (`smooth: true`):** + +```text +lr(t) = start_lr * decay_rate ^ ((t - warmup_steps) / decay_steps) +``` + +where `t` is the current training step and `warmup_steps` is the number of warmup steps (0 if warmup is not enabled). + +The formula for cosine annealing is as follows. + +**Learning rate formula for `cosine` type:** + +For cosine annealing, the learning rate smoothly decreases following a cosine curve: + +```text +lr(t) = stop_lr + (start_lr - stop_lr) / 2 * (1 + cos(pi * (t - warmup_steps) / decay_phase_steps)) +``` + +where `decay_phase_steps = numb_steps - warmup_steps` is the number of steps in the decay phase. 
+ +#### Warmup parameters (optional) + +Warmup is a technique to stabilize training in the early stages by gradually increasing the learning rate from a small initial value to the target {ref}`start_lr `. The warmup parameters are optional and can be configured as follows: + +- {ref}`warmup_steps ` (optional, default: 0) specifies the number of steps for learning rate warmup. During warmup, the learning rate increases linearly from `warmup_start_factor * start_lr` to {ref}`start_lr `. This parameter is mutually exclusive with {ref}`warmup_ratio `. +- {ref}`warmup_ratio ` (optional) specifies the warmup duration as a ratio of the total training steps. For example, `warmup_ratio: 0.1` means the warmup phase will last for 10% of the total training steps. The actual number of warmup steps is computed as `int(warmup_ratio * numb_steps)`. This parameter is mutually exclusive with {ref}`warmup_steps `. +- {ref}`warmup_start_factor ` (optional, default: 0.0) specifies the factor for the initial warmup learning rate. The warmup learning rate starts from `warmup_start_factor * start_lr` and increases linearly to {ref}`start_lr `. A value of 0.0 means the learning rate starts from zero. + +#### Configuration examples + +The following examples demonstrate various learning rate configurations. + +**Example 1: Basic exponential decay without warmup** + +```json + "learning_rate": { + "type": "exp", + "start_lr": 0.001, + "stop_lr": 1e-6, + "decay_steps": 5000 + } +``` + +**Example 2: Using stop_lr_ratio instead of stop_lr** + +```json + "learning_rate": { + "type": "exp", + "start_lr": 0.001, + "stop_lr_ratio": 1e-3, + "decay_steps": 5000 + } +``` + +This is equivalent to setting `stop_lr: 1e-6` (i.e., `0.001 * 1e-3`). + +The following example shows exponential decay with warmup using a specific number of warmup steps. + +**Example 3: Exponential decay with warmup (using warmup_steps)** + +```json + "learning_rate": { + "type": "exp", + "start_lr": 0.001, + "stop_lr": 1e-6, + "decay_steps": 5000, + "warmup_steps": 10000, + "warmup_start_factor": 0.1 } ``` -- {ref}`start_lr ` gives the learning rate at the beginning of the training. -- {ref}`stop_lr ` gives the learning rate at the end of the training. It should be small enough to ensure that the network parameters satisfactorily converge. -- During the training, the learning rate decays exponentially from {ref}`start_lr ` to {ref}`stop_lr ` following the formula: +In this example, the learning rate starts from `0.0001` (i.e., `0.1 * 0.001`) and increases linearly to `0.001` over the first 10,000 steps. After that, it decays exponentially to `1e-6`. + +The following example shows exponential decay with warmup using a ratio-based warmup duration. + +**Example 4: Exponential decay with warmup (using warmup_ratio)** - ``` - lr(t) = start_lr * decay_rate ^ ( t / decay_steps ) - ``` +```json + "learning_rate": { + "type": "exp", + "start_lr": 0.001, + "stop_lr_ratio": 1e-3, + "decay_steps": 5000, + "warmup_ratio": 0.05 + } +``` + +In this example, if the total training steps (`numb_steps`) is 1,000,000, the warmup phase will last for 50,000 steps (i.e., `0.05 * 1,000,000`). The learning rate starts from `0.0` (default `warmup_start_factor: 0.0`) and increases linearly to `0.001` over the first 50,000 steps, then decays exponentially. + +The following examples demonstrate cosine annealing configurations. 
+ +#### Cosine annealing schedule + +The {ref}`learning_rate ` section for cosine annealing in `input.json` is given as follows + +```json + "learning_rate": { + "type": "cosine", + "start_lr": 0.001, + "stop_lr": 1e-6 + } +``` + +Cosine annealing provides a smooth decay curve that often works well for training neural networks. Unlike exponential decay, it does not require the `decay_steps` parameter. + +The following example shows basic cosine annealing without warmup. + +**Example 5: Basic cosine annealing without warmup** + +```json + "learning_rate": { + "type": "cosine", + "start_lr": 0.001, + "stop_lr": 1e-6 + } +``` + +The following example shows cosine annealing with stop_lr_ratio. + +**Example 6: Cosine annealing with stop_lr_ratio** + +```json + "learning_rate": { + "type": "cosine", + "start_lr": 0.001, + "stop_lr_ratio": 1e-3 + } +``` + +This is equivalent to setting `stop_lr: 1e-6` (i.e., `0.001 * 1e-3`). + +The following example shows cosine annealing with warmup. + +**Example 7: Cosine annealing with warmup** + +```json + "learning_rate": { + "type": "cosine", + "start_lr": 0.001, + "stop_lr": 1e-6, + "warmup_steps": 5000, + "warmup_start_factor": 0.0 + } +``` + +In this example, the learning rate starts from `0.0` and increases linearly to `0.001` over the first 5,000 steps, then follows a cosine annealing curve down to `1e-6`. + +The following example shows exponential decay with smooth mode enabled. + +**Example 8: Exponential decay with smooth mode** + +```json + "learning_rate": { + "type": "exp", + "start_lr": 0.001, + "stop_lr": 1e-6, + "decay_steps": 5000, + "smooth": true + } +``` + +By setting `smooth: true`, the learning rate decays smoothly at every step instead of in a stepped manner. This provides a more gradual decay curve similar to PyTorch's `ExponentialLR`, whereas the default stepped mode (`smooth: false`) is similar to PyTorch's `StepLR`. ## Training parameters @@ -51,25 +265,25 @@ Other training parameters are given in the {ref}`training ` section. ```json "training": { - "training_data": { - "systems": ["../data_water/data_0/", "../data_water/data_1/", "../data_water/data_2/"], - "batch_size": "auto" - }, - "validation_data":{ - "systems": ["../data_water/data_3"], - "batch_size": 1, - "numb_btch": 3 - }, - "mixed_precision": { - "output_prec": "float32", - "compute_prec": "float16" - }, - - "numb_steps": 1000000, - "seed": 1, - "disp_file": "lcurve.out", - "disp_freq": 100, - "save_freq": 1000 + "training_data": { + "systems": ["../data_water/data_0/", "../data_water/data_1/", "../data_water/data_2/"], + "batch_size": "auto" + }, + "validation_data":{ + "systems": ["../data_water/data_3"], + "batch_size": 1, + "numb_btch": 3 + }, + "mixed_precision": { + "output_prec": "float32", + "compute_prec": "float16" + }, + + "numb_steps": 1000000, + "seed": 1, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 1000 } ``` @@ -85,21 +299,21 @@ The sections {ref}`training_data ` and {ref}`validation_ - An example of using `"auto_prob"` is given below. The probability of using `systems[2]` is 0.4, and the sum of the probabilities of using `systems[0]` and `systems[1]` is 0.6. If the number of frames in `systems[1]` is twice of `system[0]`, then the probability of using `system[1]` is 0.4 and that of `system[0]` is 0.2. 
```json - "training_data": { - "systems": ["../data_water/data_0/", "../data_water/data_1/", "../data_water/data_2/"], - "auto_prob": "prob_sys_size; 0:2:0.6; 2:3:0.4", - "batch_size": "auto" - } + "training_data": { + "systems": ["../data_water/data_0/", "../data_water/data_1/", "../data_water/data_2/"], + "auto_prob": "prob_sys_size; 0:2:0.6; 2:3:0.4", + "batch_size": "auto" + } ``` - The probability of using systems can also be specified explicitly with key {ref}`sys_probs ` which is a list having the length of the number of systems. For example ```json - "training_data": { - "systems": ["../data_water/data_0/", "../data_water/data_1/", "../data_water/data_2/"], - "sys_probs": [0.5, 0.3, 0.2], - "batch_size": "auto:32" - } + "training_data": { + "systems": ["../data_water/data_0/", "../data_water/data_1/", "../data_water/data_2/"], + "sys_probs": [0.5, 0.3, 0.2], + "batch_size": "auto:32" + } ``` - The key {ref}`batch_size ` specifies the number of frames used to train or validate the model in a training step. It can be set to @@ -158,9 +372,9 @@ One can use `--init-frz-model` features to adjust (increase or decrease) [`sel`] ```json "model": { - "descriptor": { - "sel": [23, 46] - } + "descriptor": { + "sel": [23, 46] + } } ``` @@ -168,7 +382,7 @@ To obtain the new model at once, [`numb_steps`](./train-input.rst) should be set ```json "training": { - "numb_steps": 0 + "numb_steps": 0 } ``` diff --git a/examples/hessian/multi_task/input.json b/examples/hessian/multi_task/input.json index b9a347581b..06ef1f9f48 100644 --- a/examples/hessian/multi_task/input.json +++ b/examples/hessian/multi_task/input.json @@ -2,22 +2,13 @@ "_comment": "that's all", "model": { "shared_dict": { - "type_map_all": [ - "C", - "H", - "N", - "O" - ], + "type_map_all": ["C", "H", "N", "O"], "dpa1_descriptor": { "type": "dpa1", "sel": 120, "rcut_smth": 0.5, "rcut": 6.0, - "neuron": [ - 25, - 50, - 100 - ], + "neuron": [25, 50, 100], "tebd_dim": 256, "axis_neuron": 16, "type_one_side": true, @@ -37,11 +28,7 @@ "type_map": "type_map_all", "descriptor": "dpa1_descriptor", "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "seed": 1, "_comment": " that's all" @@ -51,11 +38,7 @@ "type_map": "type_map_all", "descriptor": "dpa1_descriptor", "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "seed": 1, "_comment": " that's all" @@ -67,7 +50,7 @@ "type": "exp", "decay_steps": 20000, "start_lr": 0.0002, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss_dict": { @@ -100,25 +83,20 @@ "data_dict": { "H10C5N2O": { "training_data": { - "systems": [ - "../data/H10C5N2O/" - ], + "systems": ["../data/H10C5N2O/"], "batch_size": 1, "_comment": "that's all" } }, "H8C4N2O": { "training_data": { - "systems": [ - "../data/H8C4N2O/" - ], + "systems": ["../data/H8C4N2O/"], "batch_size": 1, "_comment": "that's all" } } }, "numb_steps": 1, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/examples/hessian/single_task/input.json b/examples/hessian/single_task/input.json index 3e61deac52..307767959d 100644 --- a/examples/hessian/single_task/input.json +++ b/examples/hessian/single_task/input.json @@ -1,12 +1,7 @@ { "_comment": "that's all", "model": { - "type_map": [ - "C", - "H", - "N", - "O" - ], + "type_map": ["C", "H", "N", "O"], "descriptor": { "type": "dpa2", "repinit": { @@ -14,11 +9,7 @@ "rcut": 6.0, "rcut_smth": 0.5, "nsel": 120, - "neuron": [ - 25, - 
50, - 100 - ], + "neuron": [25, 50, 100], "axis_neuron": 12, "activation_function": "tanh", "three_body_sel": 48, @@ -57,11 +48,7 @@ "add_tebd_to_repinit_out": false }, "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "precision": "float64", "seed": 1, @@ -73,7 +60,7 @@ "type": "exp", "decay_steps": 5000, "start_lr": 0.001, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss": { @@ -91,21 +78,16 @@ "training": { "stat_file": "./hess.hdf5", "training_data": { - "systems": [ - "../data/H8C4N2O" - ], + "systems": ["../data/H8C4N2O"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "../data/H10C5N2O" - ], + "systems": ["../data/H10C5N2O"], "batch_size": 1, "_comment": "that's all" }, "numb_steps": 1000000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/examples/property/train/input_torch.json b/examples/property/train/input_torch.json index 1e6ce00048..24ba87bacf 100644 --- a/examples/property/train/input_torch.json +++ b/examples/property/train/input_torch.json @@ -1,22 +1,13 @@ { "_comment": "that's all", "model": { - "type_map": [ - "H", - "C", - "N", - "O" - ], + "type_map": ["H", "C", "N", "O"], "descriptor": { "type": "dpa1", "sel": 120, "rcut_smth": 0.5, "rcut": 6.0, - "neuron": [ - 25, - 50, - 100 - ], + "neuron": [25, 50, 100], "tebd_dim": 8, "axis_neuron": 16, "type_one_side": true, @@ -34,11 +25,7 @@ "intensive": true, "task_dim": 3, "property_name": "band_prop", - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "seed": 1, "_comment": " that's all" @@ -49,36 +36,28 @@ "type": "exp", "decay_steps": 5000, "start_lr": 0.0002, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss": { "type": "property", - "metric": [ - "mae" - ], + "metric": ["mae"], "loss_func": "smooth_mae", "beta": 1.0, "_comment": " that's all" }, "training": { "training_data": { - "systems": [ - "../data/data_0", - "../data/data_1" - ], + "systems": ["../data/data_0", "../data/data_1"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "../data/data_2" - ], + "systems": ["../data/data_2"], "batch_size": 1, "_comment": "that's all" }, "numb_steps": 1000000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/examples/water/dpa2/input_torch_compressible.json b/examples/water/dpa2/input_torch_compressible.json index 14ec347b35..e46fceb1dc 100644 --- a/examples/water/dpa2/input_torch_compressible.json +++ b/examples/water/dpa2/input_torch_compressible.json @@ -1,10 +1,7 @@ { "_comment": "that's all", "model": { - "type_map": [ - "O", - "H" - ], + "type_map": ["O", "H"], "descriptor": { "type": "dpa2", "repinit": { @@ -12,11 +9,7 @@ "rcut": 6.0, "rcut_smth": 0.5, "nsel": 120, - "neuron": [ - 25, - 50, - 100 - ], + "neuron": [25, 50, 100], "axis_neuron": 12, "activation_function": "tanh", "three_body_sel": 48, @@ -57,11 +50,7 @@ "seed": 1 }, "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "precision": "float64", "seed": 1, @@ -73,7 +62,7 @@ "type": "exp", "decay_steps": 5000, "start_lr": 0.001, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss": { @@ -89,23 +78,16 @@ "training": { "stat_file": "./dpa2.hdf5", "training_data": { - "systems": [ - "../data/data_0", - "../data/data_1", - "../data/data_2" - ], + "systems": 
["../data/data_0", "../data/data_1", "../data/data_2"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "../data/data_3" - ], + "systems": ["../data/data_3"], "batch_size": 1, "_comment": "that's all" }, "numb_steps": 1000000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/examples/water/dpa2/input_torch_large.json b/examples/water/dpa2/input_torch_large.json index 4894cc6915..0e80e997e2 100644 --- a/examples/water/dpa2/input_torch_large.json +++ b/examples/water/dpa2/input_torch_large.json @@ -1,10 +1,7 @@ { "_comment": "that's all", "model": { - "type_map": [ - "O", - "H" - ], + "type_map": ["O", "H"], "descriptor": { "type": "dpa2", "repinit": { @@ -12,11 +9,7 @@ "rcut": 6.0, "rcut_smth": 0.5, "nsel": 120, - "neuron": [ - 25, - 50, - 100 - ], + "neuron": [25, 50, 100], "axis_neuron": 12, "activation_function": "tanh", "three_body_sel": 48, @@ -56,11 +49,7 @@ "seed": 1 }, "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "precision": "float64", "seed": 1, @@ -72,7 +61,7 @@ "type": "exp", "decay_steps": 5000, "start_lr": 0.001, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss": { @@ -88,23 +77,16 @@ "training": { "stat_file": "./dpa2.hdf5", "training_data": { - "systems": [ - "../data/data_0", - "../data/data_1", - "../data/data_2" - ], + "systems": ["../data/data_0", "../data/data_1", "../data/data_2"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "../data/data_3" - ], + "systems": ["../data/data_3"], "batch_size": 1, "_comment": "that's all" }, "numb_steps": 1000000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/examples/water/dpa2/input_torch_medium.json b/examples/water/dpa2/input_torch_medium.json index b752e28f31..a1e664d897 100644 --- a/examples/water/dpa2/input_torch_medium.json +++ b/examples/water/dpa2/input_torch_medium.json @@ -1,10 +1,7 @@ { "_comment": "that's all", "model": { - "type_map": [ - "O", - "H" - ], + "type_map": ["O", "H"], "descriptor": { "type": "dpa2", "repinit": { @@ -12,11 +9,7 @@ "rcut": 6.0, "rcut_smth": 0.5, "nsel": 120, - "neuron": [ - 25, - 50, - 100 - ], + "neuron": [25, 50, 100], "axis_neuron": 12, "activation_function": "tanh", "three_body_sel": 48, @@ -56,11 +49,7 @@ "seed": 1 }, "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "precision": "float64", "seed": 1, @@ -72,7 +61,7 @@ "type": "exp", "decay_steps": 5000, "start_lr": 0.001, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss": { @@ -88,23 +77,16 @@ "training": { "stat_file": "./dpa2.hdf5", "training_data": { - "systems": [ - "../data/data_0", - "../data/data_1", - "../data/data_2" - ], + "systems": ["../data/data_0", "../data/data_1", "../data/data_2"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "../data/data_3" - ], + "systems": ["../data/data_3"], "batch_size": 1, "_comment": "that's all" }, "numb_steps": 1000000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/examples/water/dpa2/input_torch_small.json b/examples/water/dpa2/input_torch_small.json index bd136a8666..abb74e7e06 100644 --- a/examples/water/dpa2/input_torch_small.json +++ b/examples/water/dpa2/input_torch_small.json @@ -1,10 +1,7 @@ { "_comment": "that's all", "model": { - "type_map": [ - "O", - 
"H" - ], + "type_map": ["O", "H"], "descriptor": { "type": "dpa2", "repinit": { @@ -12,11 +9,7 @@ "rcut": 6.0, "rcut_smth": 0.5, "nsel": 120, - "neuron": [ - 25, - 50, - 100 - ], + "neuron": [25, 50, 100], "axis_neuron": 12, "activation_function": "tanh", "three_body_sel": 48, @@ -56,11 +49,7 @@ "seed": 1 }, "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "precision": "float64", "seed": 1, @@ -72,7 +61,7 @@ "type": "exp", "decay_steps": 5000, "start_lr": 0.001, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss": { @@ -88,23 +77,16 @@ "training": { "stat_file": "./dpa2.hdf5", "training_data": { - "systems": [ - "../data/data_0", - "../data/data_1", - "../data/data_2" - ], + "systems": ["../data/data_0", "../data/data_1", "../data/data_2"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "../data/data_3" - ], + "systems": ["../data/data_3"], "batch_size": 1, "_comment": "that's all" }, "numb_steps": 1000000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/examples/water/dpa3/input_torch.json b/examples/water/dpa3/input_torch.json index ec8bba4821..791f33305b 100644 --- a/examples/water/dpa3/input_torch.json +++ b/examples/water/dpa3/input_torch.json @@ -1,10 +1,7 @@ { "_comment": "that's all", "model": { - "type_map": [ - "O", - "H" - ], + "type_map": ["O", "H"], "descriptor": { "type": "dpa3", "repflow": { @@ -38,11 +35,7 @@ "seed": 1 }, "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "precision": "float32", "activation_function": "silut:10.0", @@ -71,23 +64,16 @@ "training": { "stat_file": "./dpa3.hdf5", "training_data": { - "systems": [ - "../data/data_0", - "../data/data_1", - "../data/data_2" - ], + "systems": ["../data/data_0", "../data/data_1", "../data/data_2"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "../data/data_3" - ], + "systems": ["../data/data_3"], "batch_size": 1, "_comment": "that's all" }, "numb_steps": 1000000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/examples/water/dpa3/input_torch_dynamic.json b/examples/water/dpa3/input_torch_dynamic.json index b3137feffc..a2af9aa4db 100644 --- a/examples/water/dpa3/input_torch_dynamic.json +++ b/examples/water/dpa3/input_torch_dynamic.json @@ -1,10 +1,7 @@ { "_comment": "that's all", "model": { - "type_map": [ - "O", - "H" - ], + "type_map": ["O", "H"], "descriptor": { "type": "dpa3", "repflow": { @@ -40,11 +37,7 @@ "seed": 1 }, "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "precision": "float32", "activation_function": "silut:10.0", @@ -73,23 +66,16 @@ "training": { "stat_file": "./dpa3.hdf5", "training_data": { - "systems": [ - "../data/data_0", - "../data/data_1", - "../data/data_2" - ], + "systems": ["../data/data_0", "../data/data_1", "../data/data_2"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "../data/data_3" - ], + "systems": ["../data/data_3"], "batch_size": 1, "_comment": "that's all" }, "numb_steps": 1000000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/examples/water_tensor/dipole/dipole_input.json b/examples/water_tensor/dipole/dipole_input.json index 3feb1fbbc0..e226930328 100644 --- a/examples/water_tensor/dipole/dipole_input.json +++ 
b/examples/water_tensor/dipole/dipole_input.json @@ -45,6 +45,7 @@ "learning_rate": { "type": "exp", "start_lr": 0.01, + "stop_lr": 1e-7, "decay_steps": 5000, "_comment5": "that's all" }, diff --git a/examples/water_tensor/dipole/dipole_input_torch.json b/examples/water_tensor/dipole/dipole_input_torch.json index f6903d3334..baae0ce90a 100644 --- a/examples/water_tensor/dipole/dipole_input_torch.json +++ b/examples/water_tensor/dipole/dipole_input_torch.json @@ -45,6 +45,7 @@ "learning_rate": { "type": "exp", "start_lr": 0.01, + "stop_lr": 1e-7, "decay_steps": 5000, "_comment5": "that's all" }, diff --git a/source/tests/consistent/test_learning_rate.py b/source/tests/consistent/test_learning_rate.py index 5767f3165e..1c542a199b 100644 --- a/source/tests/consistent/test_learning_rate.py +++ b/source/tests/consistent/test_learning_rate.py @@ -41,34 +41,50 @@ "start_lr": 1e-3, "stop_lr": 1e-8, "decay_steps": 1000, - "stop_steps": 1000000, + "num_steps": 1000000, + "warmup_steps": 10000, }, { "type": "cosine", "start_lr": 1e-3, "stop_lr": 1e-8, - "decay_steps": 1000, - "stop_steps": 1000000, + "num_steps": 1000000, + "warmup_steps": 10000, }, ), ) class TestLearningRateConsistent(unittest.TestCase): + """Test learning rate consistency across different array backends.""" + def setUp(self) -> None: (lr_param,) = self.param self.lr = BaseLR(**lr_param) self.step = 500000 self.ref = self.lr.value(self.step) + self.warmup_step = None + self.warmup_ref = None + if self.lr.warmup_steps > 0: + self.warmup_step = self.lr.warmup_steps // 2 + self.warmup_ref = self.lr.value(self.warmup_step) def compare_test_with_ref(self, step: Array) -> None: test = self.lr.value(step) np.testing.assert_allclose(self.ref, to_numpy_array(test), atol=1e-10) + def compare_test_with_warmup_ref(self, step: Array) -> None: + if self.warmup_ref is None: + self.skipTest("warmup not enabled") + test = self.lr.value(step) + np.testing.assert_allclose(self.warmup_ref, to_numpy_array(test), atol=1e-10) + def compare_numpy_with_ref(self, step: Array) -> None: self.compare_test_with_ref(np.asarray(step)) @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed") def test_pt_consistent_with_ref(self) -> None: self.compare_test_with_ref(to_torch_tensor(self.step)) + if self.warmup_step is not None: + self.compare_test_with_warmup_ref(to_torch_tensor(self.warmup_step)) @unittest.skipUnless( INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed" @@ -78,7 +94,11 @@ def test_pt_consistent_with_ref(self) -> None: ) def test_array_api_strict(self) -> None: self.compare_test_with_ref(xp.asarray(self.step)) + if self.warmup_step is not None: + self.compare_test_with_warmup_ref(xp.asarray(self.warmup_step)) @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") def test_jax_consistent_with_ref(self) -> None: self.compare_test_with_ref(jnp.array(self.step)) + if self.warmup_step is not None: + self.compare_test_with_warmup_ref(jnp.array(self.warmup_step)) diff --git a/source/tests/pd/model/test_model.py b/source/tests/pd/model/test_model.py index e619171e44..6d22c50e62 100644 --- a/source/tests/pd/model/test_model.py +++ b/source/tests/pd/model/test_model.py @@ -49,7 +49,7 @@ DeepmdDataSystem, ) from deepmd.tf.utils.learning_rate import ( - LearningRateExp, + LearningRateSchedule, ) from ..test_finetune import ( @@ -108,7 +108,7 @@ def __init__(self) -> None: self.start_lr = 0.001 self.stop_lr = 3.51e-8 self.decay_steps = 500 - self.stop_steps = 1600 + self.num_steps = 1600 self.start_pref_e = 1.0 self.limit_pref_e = 2.0 
self.start_pref_f = 2.0 @@ -137,7 +137,7 @@ def get_intermediate_state(self, num_steps=1): input_dict=place_holders, ) global_step = tf.train.get_or_create_global_step() - learning_rate = dp_lr.build(global_step, self.stop_steps) + learning_rate = dp_lr.build(global_step, self.num_steps) l2_l, _ = dp_loss.build( learning_rate=learning_rate, natoms=place_holders["natoms_vec"], @@ -226,8 +226,13 @@ def _get_dp_loss(self): ) def _get_dp_lr(self): - return LearningRateExp( - start_lr=self.start_lr, stop_lr=self.stop_lr, decay_steps=self.decay_steps + return LearningRateSchedule( + { + "type": "exp", + "start_lr": self.start_lr, + "stop_lr": self.stop_lr, + "decay_steps": self.decay_steps, + } ) def _get_dp_placeholders(self, dataset): @@ -239,7 +244,7 @@ def _get_dp_placeholders(self, dataset): prec = tf.float64 place_holders[kk] = tf.placeholder(prec, [None], name="t_" + kk) place_holders["find_" + kk] = tf.placeholder( - tf.float32, name="t_find_" + kk + tf.float64, name="t_find_" + kk ) place_holders["type"] = tf.placeholder(tf.int32, [None], name="t_type") place_holders["natoms_vec"] = tf.placeholder( @@ -298,7 +303,12 @@ def test_consistency(self) -> None: }, ) my_model.to(DEVICE) - my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.stop_steps) + my_lr = MyLRExp( + self.start_lr, + self.stop_lr, + decay_steps=self.decay_steps, + num_steps=self.num_steps, + ) my_loss = EnergyStdLoss( starter_learning_rate=self.start_lr, start_pref_e=self.start_pref_e, diff --git a/source/tests/pd/model/water/multitask.json b/source/tests/pd/model/water/multitask.json index 2786afca59..38da01fac5 100644 --- a/source/tests/pd/model/water/multitask.json +++ b/source/tests/pd/model/water/multitask.json @@ -1,25 +1,13 @@ { "model": { "shared_dict": { - "my_type_map": [ - "O", - "H", - "B" - ], + "my_type_map": ["O", "H", "B"], "my_descriptor": { "type": "se_e2_a", - "sel": [ - 46, - 92, - 4 - ], - "rcut_smth": 0.50, - "rcut": 6.00, - "neuron": [ - 25, - 50, - 100 - ], + "sel": [46, 92, 4], + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [25, 50, 100], "resnet_dt": false, "axis_neuron": 16, "seed": 1, @@ -32,11 +20,7 @@ "type_map": "my_type_map", "descriptor": "my_descriptor", "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "seed": 1, "_comment": " that's all" @@ -47,11 +31,7 @@ "type_map": "my_type_map", "descriptor": "my_descriptor", "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "seed": 1, "_comment": " that's all" @@ -65,7 +45,7 @@ "decay_steps": 5000, "start_lr": 0.0002, "decay_rate": 0.98, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss_dict": { @@ -97,16 +77,12 @@ "model_1": { "stat_file": "./stat_files/model_1.hdf5", "training_data": { - "systems": [ - "pd/water/data/data_0" - ], + "systems": ["pd/water/data/data_0"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "pd/water/data/data_0" - ], + "systems": ["pd/water/data/data_0"], "batch_size": 1, "_comment": "that's all" } @@ -114,23 +90,18 @@ "model_2": { "stat_file": "./stat_files/model_2.hdf5", "training_data": { - "systems": [ - "pd/water/data/data_0" - ], + "systems": ["pd/water/data/data_0"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "pd/water/data/data_0" - ], + "systems": ["pd/water/data/data_0"], "batch_size": 1, "_comment": "that's all" } } }, "numb_steps": 100000, - "warmup_steps": 0, "gradient_max_norm": 5.0, 
"seed": 10, "disp_file": "lcurve.out", diff --git a/source/tests/pd/model/water/multitask_sharefit.json b/source/tests/pd/model/water/multitask_sharefit.json index 934ef04998..d2d04df199 100644 --- a/source/tests/pd/model/water/multitask_sharefit.json +++ b/source/tests/pd/model/water/multitask_sharefit.json @@ -1,25 +1,13 @@ { "model": { "shared_dict": { - "my_type_map": [ - "O", - "H", - "B" - ], + "my_type_map": ["O", "H", "B"], "my_descriptor": { "type": "se_e2_a", - "sel": [ - 46, - 92, - 4 - ], - "rcut_smth": 0.50, - "rcut": 6.00, - "neuron": [ - 25, - 50, - 100 - ], + "sel": [46, 92, 4], + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [25, 50, 100], "resnet_dt": false, "axis_neuron": 16, "seed": 1, @@ -27,11 +15,7 @@ }, "my_fitting": { "dim_case_embd": 2, - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "seed": 1, "_comment": " that's all" @@ -58,7 +42,7 @@ "decay_steps": 5000, "start_lr": 0.0002, "decay_rate": 0.98, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss_dict": { @@ -90,16 +74,12 @@ "model_1": { "stat_file": "./stat_files/model_1.hdf5", "training_data": { - "systems": [ - "pd/water/data/data_0" - ], + "systems": ["pd/water/data/data_0"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "pd/water/data/data_0" - ], + "systems": ["pd/water/data/data_0"], "batch_size": 1, "_comment": "that's all" } @@ -107,23 +87,18 @@ "model_2": { "stat_file": "./stat_files/model_2.hdf5", "training_data": { - "systems": [ - "pd/water/data/data_0" - ], + "systems": ["pd/water/data/data_0"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "pd/water/data/data_0" - ], + "systems": ["pd/water/data/data_0"], "batch_size": 1, "_comment": "that's all" } } }, "numb_steps": 100000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/source/tests/pd/test_lr.py b/source/tests/pd/test_lr.py index 9607f982fd..0cc054dce6 100644 --- a/source/tests/pd/test_lr.py +++ b/source/tests/pd/test_lr.py @@ -9,8 +9,8 @@ from deepmd.dpmodel.utils.learning_rate import ( LearningRateExp, ) -from deepmd.tf.utils import ( - learning_rate, +from deepmd.tf.utils.learning_rate import ( + LearningRateSchedule, ) @@ -18,20 +18,26 @@ class TestLearningRate(unittest.TestCase): def setUp(self): self.start_lr = 0.001 self.stop_lr = 3.51e-8 - self.decay_steps = np.arange(400, 601, 100) - self.stop_steps = np.arange(500, 1600, 500) + # decay_steps will be auto-adjusted if >= num_steps + self.decay_steps = np.arange(400, 501, 100) + self.num_steps = np.arange(500, 1600, 500) def test_consistency(self): for decay_step in self.decay_steps: - for stop_step in self.stop_steps: + for stop_step in self.num_steps: self.decay_step = decay_step self.stop_step = stop_step self.judge_it() self.decay_rate_pd() def judge_it(self): - base_lr = learning_rate.LearningRateExp( - self.start_lr, self.stop_lr, self.decay_step + base_lr = LearningRateSchedule( + { + "type": "exp", + "start_lr": self.start_lr, + "stop_lr": self.stop_lr, + "decay_steps": self.decay_step, + } ) g = tf.Graph() with g.as_default(): @@ -39,7 +45,10 @@ def judge_it(self): t_lr = base_lr.build(global_step, self.stop_step) my_lr = LearningRateExp( - self.start_lr, self.stop_lr, self.decay_step, self.stop_step + start_lr=self.start_lr, + stop_lr=self.stop_lr, + decay_steps=self.decay_step, + num_steps=self.stop_step, ) with tf.Session(graph=g) as sess: base_vals = [ @@ -57,44 +66,46 @@ def 
judge_it(self): def decay_rate_pd(self): my_lr = LearningRateExp( - self.start_lr, self.stop_lr, self.decay_step, self.stop_step + start_lr=self.start_lr, + stop_lr=self.stop_lr, + decay_steps=self.decay_step, + num_steps=self.stop_step, ) - default_ds = 100 if self.stop_step // 10 > 100 else self.stop_step // 100 + 1 - if self.decay_step >= self.stop_step: - self.decay_step = default_ds + # Use the auto-adjusted decay_steps from my_lr for consistency + actual_decay_steps = my_lr.decay_steps decay_rate = np.exp( - np.log(self.stop_lr / self.start_lr) / (self.stop_step / self.decay_step) + np.log(self.stop_lr / self.start_lr) / (self.stop_step / actual_decay_steps) ) my_lr_decay = LearningRateExp( - self.start_lr, - 1e-10, - self.decay_step, - self.stop_step, + start_lr=self.start_lr, + stop_lr=1e-10, + decay_steps=actual_decay_steps, + num_steps=self.stop_step, decay_rate=decay_rate, ) min_lr = 1e-5 my_lr_decay_trunc = LearningRateExp( - self.start_lr, - min_lr, - self.decay_step, - self.stop_step, + start_lr=self.start_lr, + stop_lr=min_lr, + decay_steps=actual_decay_steps, + num_steps=self.stop_step, decay_rate=decay_rate, ) my_vals = [ my_lr.value(step_id) for step_id in range(self.stop_step) - if step_id % self.decay_step != 0 + if step_id % actual_decay_steps != 0 ] my_vals_decay = [ my_lr_decay.value(step_id) for step_id in range(self.stop_step) - if step_id % self.decay_step != 0 + if step_id % actual_decay_steps != 0 ] my_vals_decay_trunc = [ my_lr_decay_trunc.value(step_id) for step_id in range(self.stop_step) - if step_id % self.decay_step != 0 + if step_id % actual_decay_steps != 0 ] self.assertTrue(np.allclose(my_vals_decay, my_vals)) self.assertTrue( diff --git a/source/tests/pt/model/test_model.py b/source/tests/pt/model/test_model.py index eee0e9beef..0b39279142 100644 --- a/source/tests/pt/model/test_model.py +++ b/source/tests/pt/model/test_model.py @@ -49,7 +49,7 @@ DeepmdDataSystem, ) from deepmd.tf.utils.learning_rate import ( - LearningRateExp, + LearningRateSchedule, ) from ..test_finetune import ( @@ -108,7 +108,7 @@ def __init__(self) -> None: self.start_lr = 0.001 self.stop_lr = 3.51e-8 self.decay_steps = 500 - self.stop_steps = 1600 + self.num_steps = 1600 self.start_pref_e = 1.0 self.limit_pref_e = 2.0 self.start_pref_f = 2.0 @@ -137,7 +137,7 @@ def get_intermediate_state(self, num_steps=1): input_dict=place_holders, ) global_step = tf.train.get_or_create_global_step() - learning_rate = dp_lr.build(global_step, self.stop_steps) + learning_rate = dp_lr.build(global_step, self.num_steps) l2_l, _ = dp_loss.build( learning_rate=learning_rate, natoms=place_holders["natoms_vec"], @@ -226,8 +226,13 @@ def _get_dp_loss(self): ) def _get_dp_lr(self): - return LearningRateExp( - start_lr=self.start_lr, stop_lr=self.stop_lr, decay_steps=self.decay_steps + return LearningRateSchedule( + { + "type": "exp", + "start_lr": self.start_lr, + "stop_lr": self.stop_lr, + "decay_steps": self.decay_steps, + } ) def _get_dp_placeholders(self, dataset): @@ -239,7 +244,7 @@ def _get_dp_placeholders(self, dataset): prec = tf.float64 place_holders[kk] = tf.placeholder(prec, [None], name="t_" + kk) place_holders["find_" + kk] = tf.placeholder( - tf.float32, name="t_find_" + kk + tf.float64, name="t_find_" + kk ) place_holders["type"] = tf.placeholder(tf.int32, [None], name="t_type") place_holders["natoms_vec"] = tf.placeholder( @@ -298,7 +303,12 @@ def test_consistency(self) -> None: }, ) my_model.to(DEVICE) - my_lr = MyLRExp(self.start_lr, self.stop_lr, self.decay_steps, self.stop_steps) + 
my_lr = MyLRExp( + self.start_lr, + self.stop_lr, + decay_steps=self.decay_steps, + num_steps=self.num_steps, + ) my_loss = EnergyStdLoss( starter_learning_rate=self.start_lr, start_pref_e=self.start_pref_e, diff --git a/source/tests/pt/model/water/multitask.json b/source/tests/pt/model/water/multitask.json index e8d998e6f1..b412f7f2c4 100644 --- a/source/tests/pt/model/water/multitask.json +++ b/source/tests/pt/model/water/multitask.json @@ -1,25 +1,13 @@ { "model": { "shared_dict": { - "my_type_map": [ - "O", - "H", - "B" - ], + "my_type_map": ["O", "H", "B"], "my_descriptor": { "type": "se_e2_a", - "sel": [ - 46, - 92, - 4 - ], - "rcut_smth": 0.50, - "rcut": 6.00, - "neuron": [ - 25, - 50, - 100 - ], + "sel": [46, 92, 4], + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [25, 50, 100], "resnet_dt": false, "axis_neuron": 16, "seed": 1, @@ -32,11 +20,7 @@ "type_map": "my_type_map", "descriptor": "my_descriptor", "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "seed": 1, "_comment": " that's all" @@ -47,11 +31,7 @@ "type_map": "my_type_map", "descriptor": "my_descriptor", "fitting_net": { - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "seed": 1, "_comment": " that's all" @@ -65,7 +45,7 @@ "decay_steps": 5000, "start_lr": 0.0002, "decay_rate": 0.98, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss_dict": { @@ -97,16 +77,12 @@ "model_1": { "stat_file": "./stat_files/model_1.hdf5", "training_data": { - "systems": [ - "pt/water/data/data_0" - ], + "systems": ["pt/water/data/data_0"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "pt/water/data/data_0" - ], + "systems": ["pt/water/data/data_0"], "batch_size": 1, "_comment": "that's all" } @@ -114,23 +90,18 @@ "model_2": { "stat_file": "./stat_files/model_2.hdf5", "training_data": { - "systems": [ - "pt/water/data/data_0" - ], + "systems": ["pt/water/data/data_0"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "pt/water/data/data_0" - ], + "systems": ["pt/water/data/data_0"], "batch_size": 1, "_comment": "that's all" } } }, "numb_steps": 100000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/source/tests/pt/model/water/multitask_sharefit.json b/source/tests/pt/model/water/multitask_sharefit.json index 246b5992f7..2f53740e2d 100644 --- a/source/tests/pt/model/water/multitask_sharefit.json +++ b/source/tests/pt/model/water/multitask_sharefit.json @@ -1,25 +1,13 @@ { "model": { "shared_dict": { - "my_type_map": [ - "O", - "H", - "B" - ], + "my_type_map": ["O", "H", "B"], "my_descriptor": { "type": "se_e2_a", - "sel": [ - 46, - 92, - 4 - ], - "rcut_smth": 0.50, - "rcut": 6.00, - "neuron": [ - 25, - 50, - 100 - ], + "sel": [46, 92, 4], + "rcut_smth": 0.5, + "rcut": 6.0, + "neuron": [25, 50, 100], "resnet_dt": false, "axis_neuron": 16, "seed": 1, @@ -27,11 +15,7 @@ }, "my_fitting": { "dim_case_embd": 2, - "neuron": [ - 240, - 240, - 240 - ], + "neuron": [240, 240, 240], "resnet_dt": true, "seed": 1, "_comment": " that's all" @@ -58,7 +42,7 @@ "decay_steps": 5000, "start_lr": 0.0002, "decay_rate": 0.98, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss_dict": { @@ -90,16 +74,12 @@ "model_1": { "stat_file": "./stat_files/model_1.hdf5", "training_data": { - "systems": [ - "pt/water/data/data_0" - ], + "systems": ["pt/water/data/data_0"], "batch_size": 1, "_comment": "that's all" }, 
"validation_data": { - "systems": [ - "pt/water/data/data_0" - ], + "systems": ["pt/water/data/data_0"], "batch_size": 1, "_comment": "that's all" } @@ -107,23 +87,18 @@ "model_2": { "stat_file": "./stat_files/model_2.hdf5", "training_data": { - "systems": [ - "pt/water/data/data_0" - ], + "systems": ["pt/water/data/data_0"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "pt/water/data/data_0" - ], + "systems": ["pt/water/data/data_0"], "batch_size": 1, "_comment": "that's all" } } }, "numb_steps": 100000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/source/tests/pt/property/input.json b/source/tests/pt/property/input.json index 44bc1e6005..1d6a4172cf 100644 --- a/source/tests/pt/property/input.json +++ b/source/tests/pt/property/input.json @@ -1,24 +1,13 @@ { "_comment": "that's all", "model": { - "type_map": [ - "H", - "C", - "N", - "O" - ], + "type_map": ["H", "C", "N", "O"], "descriptor": { "type": "se_e2_a", - "sel": [ - 90 - ], + "sel": [90], "rcut_smth": 1.8, "rcut": 6.0, - "neuron": [ - 25, - 50, - 100 - ], + "neuron": [25, 50, 100], "resnet_dt": false, "axis_neuron": 8, "precision": "float64", @@ -29,11 +18,7 @@ "intensive": true, "property_name": "band_property", "task_dim": 3, - "neuron": [ - 100, - 100, - 100 - ], + "neuron": [100, 100, 100], "resnet_dt": true, "seed": 1, "_comment": " that's all" @@ -44,7 +29,7 @@ "type": "exp", "decay_steps": 5000, "start_lr": 0.0002, - "stop_lr": 3.51e-08, + "stop_lr": 3.51e-8, "_comment": "that's all" }, "loss": { @@ -53,21 +38,16 @@ }, "training": { "training_data": { - "systems": [ - "pt/property/single" - ], + "systems": ["pt/property/single"], "batch_size": 1, "_comment": "that's all" }, "validation_data": { - "systems": [ - "pt/property/single" - ], + "systems": ["pt/property/single"], "batch_size": 1, "_comment": "that's all" }, "numb_steps": 1000000, - "warmup_steps": 0, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/source/tests/pt/test_lr.py b/source/tests/pt/test_lr.py index 75f663f041..9516c056de 100644 --- a/source/tests/pt/test_lr.py +++ b/source/tests/pt/test_lr.py @@ -10,8 +10,8 @@ LearningRateCosine, LearningRateExp, ) -from deepmd.tf.utils import ( - learning_rate, +from deepmd.tf.utils.learning_rate import ( + LearningRateSchedule, ) @@ -19,20 +19,26 @@ class TestLearningRate(unittest.TestCase): def setUp(self) -> None: self.start_lr = 0.001 self.stop_lr = 3.51e-8 - self.decay_steps = np.arange(400, 601, 100) - self.stop_steps = np.arange(500, 1600, 500) + # decay_steps will be auto-adjusted if >= num_steps + self.decay_steps = np.arange(400, 501, 100) + self.num_steps = np.arange(500, 1600, 500) def test_consistency(self) -> None: for decay_step in self.decay_steps: - for stop_step in self.stop_steps: + for stop_step in self.num_steps: self.decay_step = decay_step self.stop_step = stop_step self.judge_it() self.decay_rate_pt() def judge_it(self) -> None: - base_lr = learning_rate.LearningRateExp( - self.start_lr, self.stop_lr, self.decay_step + base_lr = LearningRateSchedule( + { + "type": "exp", + "start_lr": self.start_lr, + "stop_lr": self.stop_lr, + "decay_steps": self.decay_step, + } ) g = tf.Graph() with g.as_default(): @@ -40,7 +46,10 @@ def judge_it(self) -> None: t_lr = base_lr.build(global_step, self.stop_step) my_lr = LearningRateExp( - self.start_lr, self.stop_lr, self.decay_step, self.stop_step + start_lr=self.start_lr, + stop_lr=self.stop_lr, + decay_steps=self.decay_step, + 
num_steps=self.stop_step, ) with tf.Session(graph=g) as sess: base_vals = [ @@ -58,44 +67,46 @@ def judge_it(self) -> None: def decay_rate_pt(self) -> None: my_lr = LearningRateExp( - self.start_lr, self.stop_lr, self.decay_step, self.stop_step + start_lr=self.start_lr, + stop_lr=self.stop_lr, + decay_steps=self.decay_step, + num_steps=self.stop_step, ) - default_ds = 100 if self.stop_step // 10 > 100 else self.stop_step // 100 + 1 - if self.decay_step >= self.stop_step: - self.decay_step = default_ds + # Use the auto-adjusted decay_steps from my_lr for consistency + actual_decay_steps = my_lr.decay_steps decay_rate = np.exp( - np.log(self.stop_lr / self.start_lr) / (self.stop_step / self.decay_step) + np.log(self.stop_lr / self.start_lr) / (self.stop_step / actual_decay_steps) ) my_lr_decay = LearningRateExp( - self.start_lr, - 1e-10, - self.decay_step, - self.stop_step, + start_lr=self.start_lr, + stop_lr=1e-10, + decay_steps=actual_decay_steps, + num_steps=self.stop_step, decay_rate=decay_rate, ) min_lr = 1e-5 my_lr_decay_trunc = LearningRateExp( - self.start_lr, - min_lr, - self.decay_step, - self.stop_step, + start_lr=self.start_lr, + stop_lr=min_lr, + decay_steps=actual_decay_steps, + num_steps=self.stop_step, decay_rate=decay_rate, ) my_vals = [ my_lr.value(step_id) for step_id in range(self.stop_step) - if step_id % self.decay_step != 0 + if step_id % actual_decay_steps != 0 ] my_vals_decay = [ my_lr_decay.value(step_id) for step_id in range(self.stop_step) - if step_id % self.decay_step != 0 + if step_id % actual_decay_steps != 0 ] my_vals_decay_trunc = [ my_lr_decay_trunc.value(step_id) for step_id in range(self.stop_step) - if step_id % self.decay_step != 0 + if step_id % actual_decay_steps != 0 ] self.assertTrue(np.allclose(my_vals_decay, my_vals)) self.assertTrue( @@ -107,14 +118,18 @@ class TestLearningRateCosine(unittest.TestCase): def test_basic_curve(self) -> None: start_lr = 1.0 stop_lr = 0.1 - stop_steps = 10 - lr = LearningRateCosine(start_lr, stop_lr, stop_steps) + num_steps = 10 + lr = LearningRateCosine( + start_lr=start_lr, + stop_lr=stop_lr, + num_steps=num_steps, + ) self.assertTrue(np.allclose(lr.value(0), start_lr)) - self.assertTrue(np.allclose(lr.value(stop_steps), stop_lr)) - self.assertTrue(np.allclose(lr.value(stop_steps + 5), stop_lr)) + self.assertTrue(np.allclose(lr.value(num_steps), stop_lr)) + self.assertTrue(np.allclose(lr.value(num_steps + 5), stop_lr)) - mid_step = stop_steps // 2 + mid_step = num_steps // 2 expected_mid = stop_lr + (start_lr - stop_lr) * 0.5 self.assertTrue(np.allclose(lr.value(mid_step), expected_mid)) diff --git a/source/tests/tf/test_lr.py b/source/tests/tf/test_lr.py new file mode 100644 index 0000000000..2ddc8edb27 --- /dev/null +++ b/source/tests/tf/test_lr.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for TensorFlow learning rate schedule wrapper. + +This module tests the TF-specific wrapper logic only. +Core learning rate algorithms are tested in dpmodel tests. 
+""" + +import unittest + +import numpy as np + +from deepmd.dpmodel.utils.learning_rate import ( + LearningRateExp, +) +from deepmd.tf.env import ( + GLOBAL_TF_FLOAT_PRECISION, + tf, +) +from deepmd.tf.utils.learning_rate import ( + LearningRateSchedule, +) + + +class TestLearningRateScheduleValidation(unittest.TestCase): + """Test TF wrapper validation and error handling.""" + + def test_value_before_build(self) -> None: + """Test that calling value() before build() raises RuntimeError.""" + lr_schedule = LearningRateSchedule({"start_lr": 1e-3}) + with self.assertRaises(RuntimeError) as cm: + lr_schedule.value(100) + self.assertIn("not built", str(cm.exception)) + + def test_base_lr_before_build(self) -> None: + """Test that accessing base_lr before build() raises RuntimeError.""" + lr_schedule = LearningRateSchedule({"start_lr": 1e-3}) + with self.assertRaises(RuntimeError) as cm: + _ = lr_schedule.base_lr + self.assertIn("not built", str(cm.exception)) + + +class TestLearningRateScheduleBuild(unittest.TestCase): + """Test TF tensor building and integration.""" + + def test_build_returns_tensor(self) -> None: + """Test that build() returns a TF tensor with correct dtype.""" + lr_schedule = LearningRateSchedule({"start_lr": 1e-3, "stop_lr": 1e-5}) + global_step = tf.constant(0, dtype=tf.int64) + lr_tensor = lr_schedule.build(global_step, num_steps=10000) + + self.assertIsInstance(lr_tensor, tf.Tensor) + self.assertEqual(lr_tensor.dtype, GLOBAL_TF_FLOAT_PRECISION) + + def test_default_type_exp(self) -> None: + """Test that default type is 'exp' when not specified.""" + lr_schedule = LearningRateSchedule({"start_lr": 1e-3, "stop_lr": 1e-5}) + global_step = tf.constant(0, dtype=tf.int64) + lr_schedule.build(global_step, num_steps=10000) + + self.assertIsInstance(lr_schedule.base_lr, LearningRateExp) + + def test_value_method_matches_base_lr(self) -> None: + """Test that value() method matches BaseLR.value() after build.""" + lr_schedule = LearningRateSchedule( + { + "start_lr": 1e-3, + "stop_lr": 1e-5, + "type": "exp", + "decay_steps": 1000, + } + ) + test_step = 5000 + global_step = tf.constant(test_step, dtype=tf.int64) + lr_schedule.build(global_step, num_steps=10000) + + # value() method returns base_lr.value() as float + method_value = lr_schedule.value(test_step) + base_lr_value = lr_schedule.base_lr.value(test_step) + + np.testing.assert_allclose(method_value, base_lr_value, rtol=1e-10) + + def test_start_lr_accessor(self) -> None: + """Test start_lr() accessor returns correct value.""" + lr_schedule = LearningRateSchedule({"start_lr": 1e-3}) + self.assertEqual(lr_schedule.start_lr(), 1e-3) + + def test_value_after_build(self) -> None: + """Test value() works correctly after build().""" + lr_schedule = LearningRateSchedule( + { + "start_lr": 1e-3, + "stop_lr": 1e-5, + "type": "exp", + "decay_steps": 1000, + } + ) + global_step = tf.constant(0, dtype=tf.int64) + lr_schedule.build(global_step, num_steps=10000) + + # value() should work after build + lr_value = lr_schedule.value(5000) + expected = lr_schedule.base_lr.value(5000) + + np.testing.assert_allclose(lr_value, expected, rtol=1e-10) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/universal/dpmodel/utils/test_learning_rate.py b/source/tests/universal/dpmodel/utils/test_learning_rate.py new file mode 100644 index 0000000000..17d6d48d2e --- /dev/null +++ b/source/tests/universal/dpmodel/utils/test_learning_rate.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import 
numpy as np + +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.dpmodel.utils.learning_rate import ( + LearningRateCosine, + LearningRateExp, +) + + +class TestLearningRateExpBasic(unittest.TestCase): + """Test basic exponential decay learning rate functionality.""" + + def test_basic_decay(self) -> None: + """Test basic exponential decay without warmup.""" + lr = LearningRateExp( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + decay_steps=5000, + ) + np.testing.assert_allclose(lr.value(0), 1e-3, rtol=1e-10) + np.testing.assert_allclose(lr.value(10000), 1e-5, rtol=1e-5) + + def test_stop_lr_ratio(self) -> None: + """Test stop_lr_ratio parameter.""" + lr = LearningRateExp( + start_lr=1e-3, + stop_lr_ratio=0.01, + num_steps=10000, + decay_steps=5000, + ) + np.testing.assert_allclose(lr.stop_lr, 1e-5, rtol=1e-10) + np.testing.assert_allclose(lr.value(10000), 1e-5, rtol=1e-5) + + def test_decay_rate_override(self) -> None: + """Test explicit decay_rate parameter.""" + lr = LearningRateExp( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + decay_steps=1000, + decay_rate=0.9, + ) + self.assertEqual(lr.decay_rate, 0.9) + np.testing.assert_allclose(lr.value(1000), 1e-3 * 0.9, rtol=1e-10) + + +class TestLearningRateCosineBasic(unittest.TestCase): + """Test basic cosine annealing learning rate functionality.""" + + def test_basic_cosine(self) -> None: + """Test basic cosine annealing without warmup.""" + lr = LearningRateCosine( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + ) + np.testing.assert_allclose(lr.value(0), 1e-3, rtol=1e-10) + np.testing.assert_allclose(lr.value(10000), 1e-5, rtol=1e-10) + np.testing.assert_allclose(lr.value(5000), (1e-3 + 1e-5) / 2, rtol=1e-5) + + def test_stop_lr_ratio(self) -> None: + """Test stop_lr_ratio parameter.""" + lr = LearningRateCosine( + start_lr=1e-3, + stop_lr_ratio=0.01, + num_steps=10000, + ) + np.testing.assert_allclose(lr.stop_lr, 1e-5, rtol=1e-10) + + +class TestLearningRateWarmup(unittest.TestCase): + """Test learning rate warmup functionality.""" + + def test_warmup_steps_exp(self) -> None: + """Test warmup with exponential decay.""" + lr = LearningRateExp( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + decay_steps=1000, + warmup_steps=1000, + ) + self.assertEqual(lr.decay_num_steps, 9000) + np.testing.assert_allclose(lr.value(0), 0.0, rtol=1e-10) + np.testing.assert_allclose(lr.value(500), 0.5e-3, rtol=1e-10) + np.testing.assert_allclose(lr.value(1000), 1e-3, rtol=1e-10) + # Step 2000: 1000 steps into decay phase (1 decay period with decay_steps=1000) + # lr = start_lr * decay_rate^1 = 1e-3 * exp(log(0.01)/9) ≈ 5.995e-4 + np.testing.assert_allclose( + to_numpy_array(lr.value(2000)), 1e-3 * np.exp(np.log(0.01) / 9), rtol=1e-5 + ) + + def test_warmup_steps_cosine(self) -> None: + """Test warmup with cosine annealing.""" + lr = LearningRateCosine( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + warmup_steps=1000, + ) + self.assertEqual(lr.decay_num_steps, 9000) + np.testing.assert_allclose(lr.value(0), 0.0, rtol=1e-10) + np.testing.assert_allclose(lr.value(1000), 1e-3, rtol=1e-10) + np.testing.assert_allclose(lr.value(10000), 1e-5, rtol=1e-10) + + def test_warmup_ratio(self) -> None: + """Test warmup_ratio parameter.""" + lr = LearningRateExp( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + decay_steps=1000, + warmup_ratio=0.1, + ) + self.assertEqual(lr.warmup_steps, 1000) + self.assertEqual(lr.decay_num_steps, 9000) + + def test_warmup_start_factor(self) -> None: + """Test warmup_start_factor 
parameter.""" + lr = LearningRateExp( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + decay_steps=1000, + warmup_steps=1000, + warmup_start_factor=0.1, + ) + np.testing.assert_allclose(lr.value(0), 0.1e-3, rtol=1e-10) + np.testing.assert_allclose(lr.value(1000), 1e-3, rtol=1e-10) + + def test_no_warmup(self) -> None: + """Test that warmup_steps=0 works correctly.""" + lr = LearningRateExp( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + decay_steps=5000, + warmup_steps=0, + ) + self.assertEqual(lr.warmup_steps, 0) + self.assertEqual(lr.decay_num_steps, 10000) + np.testing.assert_allclose(lr.value(0), 1e-3, rtol=1e-10) + + +class TestLearningRateArrayInput(unittest.TestCase): + """Test learning rate with array inputs for JIT compatibility.""" + + def test_array_input_exp(self) -> None: + """Test exponential decay with array input.""" + lr = LearningRateExp( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + decay_steps=5000, + warmup_steps=1000, + ) + steps = np.array([0, 500, 1000, 5000, 10000]) + lrs = lr.value(steps) + self.assertEqual(lrs.shape, (5,)) + np.testing.assert_allclose(lrs[0], 0.0, rtol=1e-10) + np.testing.assert_allclose(lrs[2], 1e-3, rtol=1e-10) + + def test_array_input_cosine(self) -> None: + """Test cosine annealing with array input.""" + lr = LearningRateCosine( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + warmup_steps=1000, + ) + steps = np.array([0, 1000, 5500, 10000]) + lrs = lr.value(steps) + self.assertEqual(lrs.shape, (4,)) + np.testing.assert_allclose(lrs[0], 0.0, rtol=1e-10) + np.testing.assert_allclose(lrs[1], 1e-3, rtol=1e-10) + np.testing.assert_allclose(lrs[3], 1e-5, rtol=1e-10) + + +class TestLearningRateBeyondStopSteps(unittest.TestCase): + """Test learning rate behavior beyond num_steps.""" + + def test_exp_beyond_num_steps(self) -> None: + """Test exponential decay clamps to stop_lr.""" + lr = LearningRateExp( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + decay_steps=1000, + ) + np.testing.assert_allclose(lr.value(20000), 1e-5, rtol=1e-10) + + def test_cosine_beyond_num_steps(self) -> None: + """Test cosine annealing returns stop_lr beyond decay phase.""" + lr = LearningRateCosine( + start_lr=1e-3, + stop_lr=1e-5, + num_steps=10000, + ) + np.testing.assert_allclose(lr.value(20000), 1e-5, rtol=1e-10)