diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py
index fe8f1499..e603b55e 100644
--- a/src/instructlab/training/config.py
+++ b/src/instructlab/training/config.py
@@ -6,7 +6,7 @@
 
 # Standard
 from enum import Enum
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Tuple
 
 # Third Party
 from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -210,6 +210,18 @@ class TrainingArgs(BaseModel):
         description="Number of samples the model should see before saving a checkpoint. Consider this to be the checkpoint save frequency. If --save_samples<=0, this feature is disabled.",
     )
     learning_rate: float
+    adamw_weight_decay: float = Field(
+        default=0.0,
+        description="Weight decay coefficient for AdamW optimizer.",
+    )
+    adamw_betas: Tuple[float, float] = Field(
+        default=(0.9, 0.95),
+        description="Beta coefficients (beta1, beta2) for AdamW optimizer.",
+    )
+    adamw_eps: float = Field(
+        default=1e-8,
+        description="Epsilon for numerical stability in AdamW optimizer.",
+    )
     warmup_steps: int = Field(
         default=0,
         description="Number of warmup steps to run before starting the main training loop.",
diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index b73a86f3..4d7ef81a 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -419,6 +419,9 @@ def main(args):
         cpu_offload=args.cpu_offload_optimizer,
         name=None,  # choose based on backend
         learning_rate=args.learning_rate,
+        betas=(args.adamw_beta1, args.adamw_beta2),
+        weight_decay=args.adamw_weight_decay,
+        eps=args.adamw_eps,
     )
     accelerator.prepare_with_optimizer(
         optimizer=optimizer,
@@ -526,6 +529,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
             f"--log_level={train_args.log_level}",
             f"--max_batch_len={train_args.max_batch_len}",
             f"--seed={train_args.random_seed}",
+            f"--adamw_weight_decay={train_args.adamw_weight_decay}",
+            f"--adamw_beta1={train_args.adamw_betas[0]}",
+            f"--adamw_beta2={train_args.adamw_betas[1]}",
+            f"--adamw_eps={train_args.adamw_eps}",
         ]
     )
 
@@ -817,6 +824,30 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         action="store_true",
         help="Use Liger kernels for training.",
     )
+    parser.add_argument(
+        "--adamw_weight_decay",
+        type=float,
+        default=0.0,
+        help="Weight decay coefficient for AdamW optimizer.",
+    )
+    parser.add_argument(
+        "--adamw_beta1",
+        type=float,
+        default=0.9,
+        help="Beta1 coefficient for AdamW optimizer.",
+    )
+    parser.add_argument(
+        "--adamw_beta2",
+        type=float,
+        default=0.95,
+        help="Beta2 coefficient for AdamW optimizer.",
+    )
+    parser.add_argument(
+        "--adamw_eps",
+        type=float,
+        default=1e-8,
+        help="Epsilon for numerical stability in AdamW optimizer.",
+    )
     args = parser.parse_args()
     set_random_seed(args.seed)
     main(args)
diff --git a/src/instructlab/training/model.py b/src/instructlab/training/model.py
index bb89a1c4..3fdb8d8a 100644
--- a/src/instructlab/training/model.py
+++ b/src/instructlab/training/model.py
@@ -512,6 +512,8 @@ def setup_optimizer(
     name: Optimizer | None,
     learning_rate: int,
     betas: Tuple[float, float] = (0.9, 0.95),
+    weight_decay: float = 0.0,
+    eps: float = 1e-8,
 ) -> torch.optim.Optimizer:
     """Setup and return an optimizer based on the given parameters.
 
@@ -521,6 +523,8 @@ def setup_optimizer(
         name: Optional optimizer name to use
         learning_rate: Learning rate for the optimizer
         betas: Beta parameters for Adam optimizers
+        weight_decay: Weight decay coefficient for AdamW
+        eps: Epsilon for numerical stability
 
     Returns:
         A PyTorch optimizer instance
@@ -557,8 +561,11 @@ def setup_optimizer(
         )
 
     factory = functools.partial(
-        optimizer_cls, trainable_params, lr=learning_rate, betas=betas
+        optimizer_cls,
+        trainable_params,
+        lr=learning_rate,
+        betas=betas,
+        eps=eps,
+        weight_decay=weight_decay,
     )
-    if optimizer_cls is AdamW:
-        return factory(weight_decay=0.0)
     return factory()
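
For reference, a minimal sketch of what the reworked `setup_optimizer` factory produces on the AdamW path, using the new `TrainingArgs` defaults (`adamw_betas=(0.9, 0.95)`, `adamw_eps=1e-8`, `adamw_weight_decay=0.0`). The toy model and learning rate below are placeholders, not values taken from the library:

```python
# Illustrative only: mirrors the keyword arguments the updated factory now
# forwards (betas, eps, weight_decay), filled in with the new TrainingArgs defaults.
import torch
from torch.optim import AdamW

model = torch.nn.Linear(16, 16)  # stand-in for the real model
trainable_params = [p for p in model.parameters() if p.requires_grad]

optimizer = AdamW(
    trainable_params,
    lr=1e-5,              # placeholder learning rate
    betas=(0.9, 0.95),    # TrainingArgs.adamw_betas default
    eps=1e-8,             # TrainingArgs.adamw_eps default
    weight_decay=0.0,     # TrainingArgs.adamw_weight_decay default
)
```

Folding these keyword arguments into the `functools.partial` call replaces the previous AdamW-only special case, so the configured decay, betas, and epsilon are now forwarded to whichever optimizer class is selected rather than only to AdamW.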