Expose AdamW optimizer parameters in training API #674
Conversation
Walkthrough
This PR adds configurable AdamW optimizer hyperparameters across the training stack: new TrainingArgs fields, new CLI flags in the training launcher, and extended optimizer setup to accept and apply weight_decay, betas, and eps (with defaults).
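For orientation, here is a minimal sketch of what the three new `TrainingArgs` fields could look like in `config.py`. The field names, types, and defaults (`0.0`, `(0.9, 0.95)`, `1e-8`) are taken from the review comments below; the class context and descriptions are illustrative assumptions, not the repository's exact code.

```python
# Illustrative sketch only -- field names, types, and defaults come from the review
# discussion; the surrounding class body and descriptions are assumptions.
from typing import Tuple

from pydantic import BaseModel, Field


class TrainingArgs(BaseModel):
    # ... existing training fields elided ...

    # AdamW optimizer hyperparameters (new in this PR)
    adamw_weight_decay: float = Field(
        default=0.0, description="Weight decay (L2 penalty) for AdamW."
    )
    adamw_betas: Tuple[float, float] = Field(
        default=(0.9, 0.95), description="AdamW (beta1, beta2) coefficients."
    )
    adamw_eps: float = Field(
        default=1e-8, description="AdamW epsilon term for numerical stability."
    )
```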
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~15 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/instructlab/training/config.py (2 hunks)
- src/instructlab/training/main_ds.py (3 hunks)
- src/instructlab/training/model.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: unit: 3.12 on ubuntu-latest
- GitHub Check: unit: 3.13 on ubuntu-latest
- GitHub Check: unit: 3.11 on ubuntu-latest
- GitHub Check: pylint
🔇 Additional comments (7)
src/instructlab/training/config.py (2)
9-9: LGTM: Import addition is correct. The `Tuple` import is necessary for the type annotation of the `adamw_betas` field added below.
213-224: Default `beta2=0.95` for AdamW is appropriate for LLM training. The configuration uses `beta2=0.95`, which differs from PyTorch's default of `0.999` but aligns with standard practice for LLM training. Llama 3 and Llama 3.2 training configurations officially use this same value for improved training stability. No changes needed.

src/instructlab/training/main_ds.py (3)
422-424: LGTM: AdamW parameters correctly wired to optimizer setup. The new AdamW hyperparameters are properly extracted from the parsed arguments and passed to `setup_optimizer()`. The tuple unpacking of beta1 and beta2 is correct.
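As a rough, hedged illustration of the wiring described above, the sketch below uses plain `torch.optim.AdamW` as a stand-in for the repository's `setup_optimizer()`; the attribute names on `args` are assumptions, not the exact CLI spellings.

```python
# Hypothetical sketch: rebuild the betas tuple from two parsed CLI values and
# forward all AdamW hyperparameters to the optimizer. Attribute names are assumed.
import argparse

import torch


def build_optimizer(model: torch.nn.Module, args: argparse.Namespace) -> torch.optim.AdamW:
    return torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.adamw_weight_decay,
        betas=(args.adamw_beta1, args.adamw_beta2),  # tuple reassembled from two flags
        eps=args.adamw_eps,
    )
```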
532-535: LGTM: Command builder correctly propagates AdamW hyperparameters. The command construction properly extracts `beta1` and `beta2` from the `adamw_betas` tuple and formats all four AdamW parameters as CLI arguments for the subprocess.
827-850: LGTM: CLI arguments properly defined with consistent defaults. The new CLI arguments are well-documented with help text and defaults that match the `TrainingArgs` configuration in `config.py`.
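A hedged sketch of what such argparse definitions typically look like; the flag spellings and help strings are assumptions chosen to mirror the `TrainingArgs` defaults, not the repository's exact arguments.

```python
# Illustrative only -- flag names are assumptions; defaults mirror TrainingArgs.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adamw_weight_decay", type=float, default=0.0,
                    help="Weight decay (L2 penalty) applied by AdamW.")
parser.add_argument("--adamw_beta1", type=float, default=0.9,
                    help="AdamW beta1 coefficient.")
parser.add_argument("--adamw_beta2", type=float, default=0.95,
                    help="AdamW beta2 coefficient (0.95 is a common LLM-training choice).")
parser.add_argument("--adamw_eps", type=float, default=1e-8,
                    help="AdamW epsilon term for numerical stability.")
```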
src/instructlab/training/model.py (2)

515-516: LGTM: Function signature properly extended. The new `weight_decay` and `eps` parameters are correctly added with appropriate defaults that match the `TrainingArgs` configuration.
526-527: LGTM: Docstring updated to document new parameters. The documentation correctly describes the new `weight_decay` and `eps` parameters.
Add support for configuring weight_decay, betas, and eps parameters for the AdamW optimizer through TrainingArgs, allowing users to tune these hyperparameters when calling run_training().

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
Force-pushed from 32fc10d to 770b8c2
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/instructlab/training/config.py (1)
213-224: Consider adding field validators for optimizer parameters. The new TrainingArgs fields are well-structured with clear descriptions and consistent defaults. However, consider adding validators to ensure:

- `adamw_betas` values are in the range (0, 1)
- `adamw_eps` is positive
- `adamw_weight_decay` is non-negative

Additionally, the same concern applies here: the default `adamw_betas=(0.9, 0.95)` uses beta2=0.95 instead of the PyTorch standard of 0.999.

Example validator:
```python
from pydantic import field_validator

@field_validator('adamw_betas')
def validate_betas(cls, v):
    if not (0 < v[0] < 1 and 0 < v[1] < 1):
        raise ValueError('Beta values must be in range (0, 1)')
    return v

@field_validator('adamw_eps')
def validate_eps(cls, v):
    if v <= 0:
        raise ValueError('Epsilon must be positive')
    return v

@field_validator('adamw_weight_decay')
def validate_weight_decay(cls, v):
    if v < 0:
        raise ValueError('Weight decay must be non-negative')
    return v
```
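(If adopted, these validators would live inside the `TrainingArgs` class body; in Pydantic v2 style they are conventionally stacked with `@classmethod` directly beneath `@field_validator`.)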
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/instructlab/training/config.py (2 hunks)
- src/instructlab/training/main_ds.py (3 hunks)
- src/instructlab/training/model.py (3 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: unit: 3.12 on ubuntu-latest
- GitHub Check: unit: 3.11 on ubuntu-latest
- GitHub Check: unit: 3.13 on ubuntu-latest
- GitHub Check: pylint
- GitHub Check: Summary
🔇 Additional comments (6)
src/instructlab/training/model.py (2)
509-527: LGTM! Past review comment addressed. The function signature now includes `weight_decay` and `eps` parameters, enabling AdamW optimizer configuration. The defaults (weight_decay=0.0, eps=1e-8) are conservative and appropriate.
563-571: Optimizer factory correctly passes consistent parameters across all three optimizer types. The uniform application of `weight_decay` and `eps` to AdamW, FusedAdam, and DeepSpeedCPUAdam via `functools.partial` is valid; all three optimizers support these parameters with compatible signatures.
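A simplified, hedged sketch of the pattern being described: bind the shared hyperparameters once with `functools.partial` and apply the factory to the model parameters later. Only torch's `AdamW` is imported so the snippet stays self-contained; FusedAdam and DeepSpeedCPUAdam would be bound the same way since they accept compatible keyword arguments.

```python
# Sketch of the functools.partial pattern described above; names are illustrative.
import functools

import torch


def make_optimizer_factory(lr: float, weight_decay: float,
                           betas: tuple[float, float], eps: float):
    # The same keyword arguments would also be passed to FusedAdam / DeepSpeedCPUAdam,
    # whose constructors accept compatible lr/weight_decay/betas/eps kwargs.
    return functools.partial(
        torch.optim.AdamW,
        lr=lr,
        weight_decay=weight_decay,
        betas=betas,
        eps=eps,
    )


# Usage: the factory is later called with the model parameters.
model = torch.nn.Linear(8, 8)
optimizer = make_optimizer_factory(1e-4, 0.0, (0.9, 0.95), 1e-8)(model.parameters())
```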
src/instructlab/training/main_ds.py (3)

417-425: LGTM! AdamW parameters correctly wired. The optimizer setup correctly passes the AdamW hyperparameters with betas properly constructed from the individual beta1 and beta2 CLI arguments.
532-536: LGTM! CLI arguments correctly constructed. The torchrun command builder correctly includes all AdamW parameters, properly decomposing the betas tuple into individual beta1 and beta2 arguments.
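For illustration, a hedged sketch of how a command builder might append these flags to the subprocess invocation; the flag spellings and the shape of `train_args` are assumptions, not the repository's exact code.

```python
# Hypothetical command-builder fragment; flag spellings are assumptions.
def append_adamw_flags(command: list[str], train_args) -> list[str]:
    beta1, beta2 = train_args.adamw_betas  # tuple decomposed into two scalar flags
    command.extend([
        f"--adamw_weight_decay={train_args.adamw_weight_decay}",
        f"--adamw_beta1={beta1}",
        f"--adamw_beta2={beta2}",
        f"--adamw_eps={train_args.adamw_eps}",
    ])
    return command
```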
827-850: The `adamw_*` parameters use non-standard defaults that reflect language model training practices rather than PyTorch defaults. The beta2=0.95 choice (vs. PyTorch's 0.999) and weight_decay=0.0 align with modern LLM training best practices used in frameworks like Composer and recent research. Consider adding documentation explaining why these LLM-optimized defaults were chosen, if this choice isn't already documented elsewhere in the codebase.

src/instructlab/training/config.py (1)
9-9: LGTM! Necessary import for tuple type. Added the `Tuple` import to support the `adamw_betas` field type annotation.
Maxusmusti left a comment
LGTM
Summary
- Adds three new AdamW hyperparameter fields to `TrainingArgs`: `adamw_weight_decay`, `adamw_betas`, `adamw_eps`
- These can be tuned when calling `run_training()`

Changes
- `config.py`: Added 3 new fields to `TrainingArgs`
- `model.py`: Updated `setup_optimizer()` to accept `weight_decay` and `eps` parameters
- `main_ds.py`: Added CLI arguments and wired them through the command builder

Usage
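The original Usage snippet did not survive extraction; below is a minimal, hedged sketch of how the new knobs might be passed, assuming the package's `run_training(torch_args, train_args)` entry point. Only the three `adamw_*` fields come from this PR; every other field, value, and the exact required-field set are illustrative assumptions.

```python
# Hedged usage sketch -- only adamw_weight_decay, adamw_betas, and adamw_eps come
# from this PR; other fields and values are illustrative and may not match the
# library's actual required arguments.
from instructlab.training import TorchrunArgs, TrainingArgs, run_training

train_args = TrainingArgs(
    model_path="/models/base-model",      # illustrative
    data_path="/data/train.jsonl",        # illustrative
    ckpt_output_dir="/checkpoints",       # illustrative
    data_output_dir="/data/processed",    # illustrative
    max_seq_len=4096,                     # illustrative
    max_batch_len=60000,                  # illustrative
    num_epochs=1,                         # illustrative
    effective_batch_size=128,             # illustrative
    learning_rate=1e-5,                   # illustrative
    # New AdamW hyperparameters added by this PR:
    adamw_weight_decay=0.1,
    adamw_betas=(0.9, 0.95),
    adamw_eps=1e-8,
)

torch_args = TorchrunArgs(nnodes=1, nproc_per_node=8, node_rank=0,
                          rdzv_id=123, rdzv_endpoint="")  # illustrative

run_training(torch_args=torch_args, train_args=train_args)
```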
Test plan
- `TrainingArgs` imports and instantiates correctly with new fields
- `setup_optimizer()` signature has new parameters
- `--help` shows new arguments

🤖 Generated with Claude Code
Summary by CodeRabbit