14 changes: 13 additions & 1 deletion src/instructlab/training/config.py
@@ -6,7 +6,7 @@
 
 # Standard
 from enum import Enum
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Tuple
 
 # Third Party
 from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -210,6 +210,18 @@ class TrainingArgs(BaseModel):
         description="Number of samples the model should see before saving a checkpoint. Consider this to be the checkpoint save frequency. If --save_samples<=0, this feature is disabled.",
     )
     learning_rate: float
+    adamw_weight_decay: float = Field(
+        default=0.0,
+        description="Weight decay coefficient for AdamW optimizer.",
+    )
+    adamw_betas: Tuple[float, float] = Field(
+        default=(0.9, 0.95),
+        description="Beta coefficients (beta1, beta2) for AdamW optimizer.",
+    )
+    adamw_eps: float = Field(
+        default=1e-8,
+        description="Epsilon for numerical stability in AdamW optimizer.",
+    )
     warmup_steps: int = Field(
         default=0,
         description="Number of warmup steps to run before starting the main training loop.",
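Note: these fields make the AdamW hyperparameters configurable from TrainingArgs rather than fixed in code. A minimal, self-contained sketch using a toy pydantic model that mirrors just the three new fields and their defaults (the real TrainingArgs lives in instructlab.training.config and carries many more required fields):

# Toy mirror of only the three new fields, to show defaults and tuple validation.
from typing import Tuple
from pydantic import BaseModel, Field

class AdamWArgs(BaseModel):
    adamw_weight_decay: float = Field(default=0.0)
    adamw_betas: Tuple[float, float] = Field(default=(0.9, 0.95))
    adamw_eps: float = Field(default=1e-8)

args = AdamWArgs(adamw_weight_decay=0.01)
print(args.adamw_betas, args.adamw_weight_decay, args.adamw_eps)  # (0.9, 0.95) 0.01 1e-08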
31 changes: 31 additions & 0 deletions src/instructlab/training/main_ds.py
@@ -419,6 +419,9 @@ def main(args):
         cpu_offload=args.cpu_offload_optimizer,
         name=None,  # choose based on backend
         learning_rate=args.learning_rate,
+        betas=(args.adamw_beta1, args.adamw_beta2),
+        weight_decay=args.adamw_weight_decay,
+        eps=args.adamw_eps,
     )
     accelerator.prepare_with_optimizer(
         optimizer=optimizer,
@@ -526,6 +529,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
             f"--log_level={train_args.log_level}",
             f"--max_batch_len={train_args.max_batch_len}",
             f"--seed={train_args.random_seed}",
+            f"--adamw_weight_decay={train_args.adamw_weight_decay}",
+            f"--adamw_beta1={train_args.adamw_betas[0]}",
+            f"--adamw_beta2={train_args.adamw_betas[1]}",
+            f"--adamw_eps={train_args.adamw_eps}",
         ]
     )
 
@@ -817,6 +824,30 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
         action="store_true",
         help="Use Liger kernels for training.",
     )
+    parser.add_argument(
+        "--adamw_weight_decay",
+        type=float,
+        default=0.0,
+        help="Weight decay coefficient for AdamW optimizer.",
+    )
+    parser.add_argument(
+        "--adamw_beta1",
+        type=float,
+        default=0.9,
+        help="Beta1 coefficient for AdamW optimizer.",
+    )
+    parser.add_argument(
+        "--adamw_beta2",
+        type=float,
+        default=0.95,
+        help="Beta2 coefficient for AdamW optimizer.",
+    )
+    parser.add_argument(
+        "--adamw_eps",
+        type=float,
+        default=1e-8,
+        help="Epsilon for numerical stability in AdamW optimizer.",
+    )
     args = parser.parse_args()
     set_random_seed(args.seed)
     main(args)
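Note: run_training() flattens adamw_betas into two scalar flags, and main() recombines them when building the optimizer. A minimal standalone sketch of that plumbing, with flag names, types, and defaults copied from this diff (the real parser defines many more options):

# Isolated sketch of the new argparse flags; mirrors the additions above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adamw_weight_decay", type=float, default=0.0)
parser.add_argument("--adamw_beta1", type=float, default=0.9)
parser.add_argument("--adamw_beta2", type=float, default=0.95)
parser.add_argument("--adamw_eps", type=float, default=1e-8)

args = parser.parse_args(["--adamw_weight_decay", "0.01"])
betas = (args.adamw_beta1, args.adamw_beta2)  # recombined as in main(): (0.9, 0.95)
print(betas, args.adamw_weight_decay, args.adamw_eps)  # (0.9, 0.95) 0.01 1e-08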
13 changes: 10 additions & 3 deletions src/instructlab/training/model.py
@@ -512,6 +512,8 @@ def setup_optimizer(
     name: Optimizer | None,
     learning_rate: int,
     betas: Tuple[float, float] = (0.9, 0.95),
+    weight_decay: float = 0.0,
+    eps: float = 1e-8,
 ) -> torch.optim.Optimizer:
     """Setup and return an optimizer based on the given parameters.
 
@@ -521,6 +523,8 @@
         name: Optional optimizer name to use
         learning_rate: Learning rate for the optimizer
         betas: Beta parameters for Adam optimizers
+        weight_decay: Weight decay coefficient for AdamW
+        eps: Epsilon for numerical stability
 
     Returns:
         A PyTorch optimizer instance
@@ -557,8 +561,11 @@
         )
 
     factory = functools.partial(
-        optimizer_cls, trainable_params, lr=learning_rate, betas=betas
+        optimizer_cls,
+        trainable_params,
+        lr=learning_rate,
+        betas=betas,
+        eps=eps,
+        weight_decay=weight_decay,
     )
-    if optimizer_cls is AdamW:
-        return factory(weight_decay=0.0)
     return factory()
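Note: weight_decay and eps are now bound into the functools.partial factory for every optimizer class, replacing the AdamW-only branch that pinned weight_decay to 0.0. A minimal sketch of that factory pattern in isolation, assuming torch.optim.AdamW as optimizer_cls and a toy model standing in for the parameters setup_optimizer() actually collects:

# Isolated sketch of the factory pattern; AdamW and the toy model are assumptions.
import functools

import torch
from torch.optim import AdamW

model = torch.nn.Linear(4, 4)  # stand-in for the real model
trainable_params = [p for p in model.parameters() if p.requires_grad]

factory = functools.partial(
    AdamW,
    trainable_params,
    lr=1e-4,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01,
)
optimizer = factory()  # AdamW configured with the values bound above
print(optimizer.defaults["betas"], optimizer.defaults["weight_decay"])  # (0.9, 0.95) 0.01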