
Conversation


@RobotSail RobotSail commented Dec 17, 2025

Summary

  • Add support for configuring AdamW optimizer parameters through TrainingArgs
  • New fields: adamw_weight_decay, adamw_betas, adamw_eps
  • Users can now tune these hyperparameters when calling run_training()

Changes

  • config.py: Added 3 new fields to TrainingArgs
  • model.py: Updated setup_optimizer() to accept weight_decay and eps parameters
  • main_ds.py: Added CLI arguments and wired them through the command builder

Usage

train_args = TrainingArgs(
    # ... other args
    learning_rate=1e-5,
    adamw_weight_decay=0.01,
    adamw_betas=(0.9, 0.999),
    adamw_eps=1e-8,
)
run_training(torch_args, train_args)
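
For intuition, these values map directly onto the underlying PyTorch optimizer. A minimal standalone sketch of the equivalent direct construction (illustration only; run_training() builds the optimizer internally rather than exposing it like this):

import torch

model = torch.nn.Linear(8, 2)  # toy stand-in for the real model
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-5,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
)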

Test plan

  • Verified TrainingArgs imports and instantiates correctly with new fields
  • Verified setup_optimizer() signature has new parameters
  • Verified CLI --help shows new arguments
  • Passed ruff check/format, pylint (10/10)

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features
  • Added AdamW optimizer options: weight decay, beta1/beta2 momentum coefficients, and epsilon for numerical stability.
    • New command-line flags (--adamw_weight_decay, --adamw_beta1, --adamw_beta2, --adamw_eps) to control these settings during training.
    • Training configuration and runtime now accept and pass through customizable AdamW hyperparameters for finer optimization control.



coderabbitai bot commented Dec 17, 2025

Walkthrough

This PR adds configurable AdamW optimizer hyperparameters across the training stack: new TrainingArgs fields, new CLI flags in the training launcher, and extended optimizer setup to accept and apply weight_decay, betas, and eps (with defaults).

Changes

  • AdamW Hyperparameter Configuration (src/instructlab/training/config.py): Added three new fields to TrainingArgs: adamw_weight_decay (float, default 0.0), adamw_betas (Tuple[float, float], default (0.9, 0.95)), and adamw_eps (float, default 1e-8); updated imports to include Tuple.
  • AdamW CLI Arguments & Propagation (src/instructlab/training/main_ds.py): Introduced CLI arguments (--adamw_weight_decay, --adamw_beta1, --adamw_beta2, --adamw_eps) and wired them through the training invocation and distributed launcher to pass the hyperparameters downstream.
  • Optimizer Setup Integration (src/instructlab/training/model.py): Extended the setup_optimizer() signature with optional weight_decay: float = 0.0 and eps: float = 1e-8, updated its docstring, and passed these parameters into the optimizer factory (betas are also accepted).

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~15 minutes

  • Check consistency of CLI flag names vs. TrainingArgs field names and defaults.
  • Verify adamw_betas parsing/usage (tuple construction from --adamw_beta1/--adamw_beta2).
  • Confirm optimizer factory receives and applies weight_decay and eps as intended.

Poem

🐰 I nibble on configs, soft and light,
I tuck in betas, epsilon bright,
A little decay, a gentle tune,
The optimizer hums beneath the moon,
Hopping to train through the night ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 66.67%, which is below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately and concisely describes the main change: exposing AdamW optimizer parameters in the training API, which is the core focus of all three modified files.




@mergify mergify bot added the ci-failure label Dec 17, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3d05302 and 32fc10d.

📒 Files selected for processing (3)
  • src/instructlab/training/config.py (2 hunks)
  • src/instructlab/training/main_ds.py (3 hunks)
  • src/instructlab/training/model.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: unit: 3.12 on ubuntu-latest
  • GitHub Check: unit: 3.13 on ubuntu-latest
  • GitHub Check: unit: 3.11 on ubuntu-latest
  • GitHub Check: pylint
🔇 Additional comments (7)
src/instructlab/training/config.py (2)

9-9: LGTM: Import addition is correct.

The Tuple import is necessary for the type annotation of the adamw_betas field added below.


213-224: Default beta2=0.95 for AdamW is appropriate for LLM training.

The configuration uses beta2=0.95, which differs from PyTorch's default of 0.999 but aligns with standard practices for LLM training. Llama 3 and Llama 3.2 training configurations officially use this same value for improved training stability. No changes needed.
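
For reference, a minimal sketch of what these field declarations might look like (an assumption based on the defaults reported in the walkthrough; the actual descriptions and base class in config.py may differ):

from typing import Tuple

from pydantic import BaseModel, Field

class TrainingArgs(BaseModel):
    # ... existing training fields omitted ...
    adamw_weight_decay: float = Field(default=0.0)
    adamw_betas: Tuple[float, float] = Field(default=(0.9, 0.95))
    adamw_eps: float = Field(default=1e-8)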

src/instructlab/training/main_ds.py (3)

422-424: LGTM: AdamW parameters correctly wired to optimizer setup.

The new AdamW hyperparameters are properly extracted from the parsed arguments and passed to setup_optimizer(). The tuple unpacking of beta1 and beta2 is correct.
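
A self-contained sketch of the tuple construction this refers to (argparse.Namespace stands in for the parsed arguments; setup_optimizer itself is not reproduced here):

import argparse

# Stand-in for the values main_ds.py would parse from the command line.
args = argparse.Namespace(
    adamw_weight_decay=0.0,
    adamw_beta1=0.9,
    adamw_beta2=0.95,
    adamw_eps=1e-8,
)

# The two scalar flags are recombined into the (beta1, beta2) tuple the optimizer expects.
betas = (args.adamw_beta1, args.adamw_beta2)
assert betas == (0.9, 0.95)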


532-535: LGTM: Command builder correctly propagates AdamW hyperparameters.

The command construction properly extracts beta1 and beta2 from the adamw_betas tuple and formats all four AdamW parameters as CLI arguments for the subprocess.
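
The propagation presumably looks something like the following (illustrative sketch only; variable names such as command and the launcher invocation are assumptions, not the actual code):

adamw_weight_decay = 0.0   # stand-in for train_args.adamw_weight_decay
adamw_betas = (0.9, 0.95)  # stand-in for train_args.adamw_betas
adamw_eps = 1e-8           # stand-in for train_args.adamw_eps

command = ["torchrun", "main_ds.py"]  # simplified stand-in for the real launcher command
command.extend([
    f"--adamw_weight_decay={adamw_weight_decay}",
    f"--adamw_beta1={adamw_betas[0]}",  # betas tuple decomposed back into two flags
    f"--adamw_beta2={adamw_betas[1]}",
    f"--adamw_eps={adamw_eps}",
])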


827-850: LGTM: CLI arguments properly defined with consistent defaults.

The new CLI arguments are well-documented with help text and defaults that match the TrainingArgs configuration in config.py.
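
If defined with argparse, the four flags would look roughly like this (help strings are assumptions; the defaults mirror the TrainingArgs values noted above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adamw_weight_decay", type=float, default=0.0,
                    help="Weight decay coefficient for AdamW.")
parser.add_argument("--adamw_beta1", type=float, default=0.9,
                    help="First-moment decay rate (beta1) for AdamW.")
parser.add_argument("--adamw_beta2", type=float, default=0.95,
                    help="Second-moment decay rate (beta2) for AdamW.")
parser.add_argument("--adamw_eps", type=float, default=1e-8,
                    help="Epsilon term for numerical stability in AdamW.")
args = parser.parse_args([])  # parse with defaults for demonstration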

src/instructlab/training/model.py (2)

515-516: LGTM: Function signature properly extended.

The new weight_decay and eps parameters are correctly added with appropriate defaults that match the TrainingArgs configuration.
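
A hedged sketch of the extended signature (the surrounding body and parameter order are assumptions; only the weight_decay and eps defaults come from the review):

from typing import Tuple

import torch

def setup_optimizer(
    model: torch.nn.Module,
    lr: float,
    betas: Tuple[float, float] = (0.9, 0.95),
    weight_decay: float = 0.0,
    eps: float = 1e-8,
) -> torch.optim.Optimizer:
    """Build an AdamW optimizer for the model (simplified illustration)."""
    return torch.optim.AdamW(
        model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay, eps=eps
    )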


526-527: LGTM: Docstring updated to document new parameters.

The documentation correctly describes the new weight_decay and eps parameters.

Add support for configuring weight_decay, betas, and eps parameters
for the AdamW optimizer through TrainingArgs, allowing users to tune
these hyperparameters when calling run_training().

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
@RobotSail RobotSail force-pushed the feature/expose-adamw-params branch from 32fc10d to 770b8c2 Compare December 17, 2025 21:12

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
src/instructlab/training/config.py (1)

213-224: Consider adding field validators for optimizer parameters.

The new TrainingArgs fields are well-structured with clear descriptions and consistent defaults. However, consider adding validators to ensure:

  • adamw_betas values are in the range (0, 1)
  • adamw_eps is positive
  • adamw_weight_decay is non-negative

Additionally, the same concern applies here: the default adamw_betas=(0.9, 0.95) uses beta2=0.95 instead of the PyTorch standard of 0.999.

Example validator:

from pydantic import field_validator

@field_validator('adamw_betas')
@classmethod
def validate_betas(cls, v):
    # Both beta coefficients must lie strictly between 0 and 1.
    if not (0 < v[0] < 1 and 0 < v[1] < 1):
        raise ValueError('Beta values must be in range (0, 1)')
    return v

@field_validator('adamw_eps')
@classmethod
def validate_eps(cls, v):
    # Epsilon guards the denominator, so it must be strictly positive.
    if v <= 0:
        raise ValueError('Epsilon must be positive')
    return v

@field_validator('adamw_weight_decay')
@classmethod
def validate_weight_decay(cls, v):
    # A value of 0.0 disables weight decay; negative values are invalid.
    if v < 0:
        raise ValueError('Weight decay must be non-negative')
    return v
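
With validators along these lines in place, constructing TrainingArgs with, say, adamw_eps=0.0 or adamw_betas=(0.9, 1.2) would fail fast with a pydantic ValidationError instead of surfacing later as a cryptic optimizer error.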
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 32fc10d and 770b8c2.

📒 Files selected for processing (3)
  • src/instructlab/training/config.py (2 hunks)
  • src/instructlab/training/main_ds.py (3 hunks)
  • src/instructlab/training/model.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: unit: 3.12 on ubuntu-latest
  • GitHub Check: unit: 3.11 on ubuntu-latest
  • GitHub Check: unit: 3.13 on ubuntu-latest
  • GitHub Check: pylint
  • GitHub Check: Summary
🔇 Additional comments (6)
src/instructlab/training/model.py (2)

509-527: LGTM! Past review comment addressed.

The function signature now includes weight_decay and eps parameters, enabling AdamW optimizer configuration. The defaults (weight_decay=0.0, eps=1e-8) are conservative and appropriate.


563-571: Optimizer factory correctly passes consistent parameters across all three optimizer types.

The uniform application of weight_decay and eps to AdamW, FusedAdam, and DeepSpeedCPUAdam via functools.partial is valid—all three optimizers support these parameters with compatible signatures.
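
The factory pattern described here can be sketched as follows (only the stock torch.optim.AdamW branch is shown; per the comment above, the DeepSpeed variants accept the same keyword arguments and are selected the same way in the real code):

import functools

import torch

# Bind the hyperparameters once; the factory is later called with the model parameters.
optimizer_factory = functools.partial(
    torch.optim.AdamW,
    lr=1e-5,
    betas=(0.9, 0.95),
    weight_decay=0.0,
    eps=1e-8,
)

model = torch.nn.Linear(8, 2)  # toy model for illustration
optimizer = optimizer_factory(model.parameters())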

src/instructlab/training/main_ds.py (3)

417-425: LGTM! AdamW parameters correctly wired.

The optimizer setup correctly passes the AdamW hyperparameters with betas properly constructed from the individual beta1 and beta2 CLI arguments.


532-536: LGTM! CLI arguments correctly constructed.

The torchrun command builder correctly includes all AdamW parameters, properly decomposing the betas tuple into individual beta1 and beta2 arguments.


827-850: The adamw_* parameters use non-standard defaults that reflect language model training practices rather than PyTorch defaults. The beta2=0.95 choice (vs. PyTorch's 0.999) and weight_decay=0.0 align with modern LLM training best practices used in frameworks like Composer and recent research. Consider adding documentation explaining why these LLM-optimized defaults were chosen if this choice isn't already documented elsewhere in the codebase.
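
For readers less familiar with these knobs, a compact, purely pedagogical sketch of one decoupled AdamW step shows where each hyperparameter enters (scalar math only; the real implementation is vectorized and ordered slightly differently):

import math

def adamw_step(p, g, m, v, t, lr=1e-5, beta1=0.9, beta2=0.95, eps=1e-8, weight_decay=0.0):
    """One scalar AdamW update with decoupled weight decay, for illustration only."""
    m = beta1 * m + (1 - beta1) * g        # first-moment (momentum) estimate
    v = beta2 * v + (1 - beta2) * g * g    # second-moment estimate
    m_hat = m / (1 - beta1 ** t)           # bias correction for step t
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)  # adaptive gradient step
    p = p - lr * weight_decay * p          # decoupled weight decay (no-op when 0.0)
    return p, m, v

print(adamw_step(p=1.0, g=0.5, m=0.0, v=0.0, t=1))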

src/instructlab/training/config.py (1)

9-9: LGTM! Necessary import for tuple type.

Added Tuple import to support the adamw_betas field type annotation.


@Maxusmusti Maxusmusti left a comment


LGTM

@mergify mergify bot added the one-approval label Dec 18, 2025
@RobotSail RobotSail merged commit c495035 into main Dec 18, 2025
14 of 18 checks passed
@RobotSail RobotSail deleted the feature/expose-adamw-params branch December 18, 2025 17:11