feat(pt): add cosine annealing lr scheduler #5133
Conversation
Pull request overview
This PR adds support for cosine annealing learning rate scheduling to the PyTorch backend. The implementation provides an alternative to the existing exponential decay scheduler with a standard cosine annealing formula that smoothly decreases the learning rate from start_lr to stop_lr over the training period.
Key changes:
- Added `LearningRateCosine` class implementing cosine annealing with the formula `lr = stop_lr + (start_lr - stop_lr) * 0.5 * (1 + cos(π * step / stop_steps))` (see the sketch after this list)
- Extended configuration schema to accept "cosine" as a learning rate type option (PyTorch-only)
- Refactored training logic to support multiple learning rate scheduler types with improved error handling
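A minimal sketch of such a scheduler, assembled from the formula above and the snippets quoted later in this review; the class layout in the PR itself may differ slightly:

```python
from typing import Any

import numpy as np


class LearningRateCosine:
    """Cosine-annealed learning rate, decaying from start_lr to stop_lr over stop_steps."""

    def __init__(self, start_lr: float, stop_lr: float, stop_steps: int, **kwargs: Any) -> None:
        self.start_lr = start_lr
        self.stop_lr = stop_lr
        # Clamp to at least one step to avoid division by zero in value().
        self.stop_steps = max(1, stop_steps)

    def value(self, step: int) -> np.float64:
        """Learning rate at the given step; plateaus at stop_lr beyond stop_steps."""
        clamped_step = min(step, self.stop_steps)
        cosine = 0.5 * (1.0 + np.cos(np.pi * clamped_step / self.stop_steps))
        return np.float64(self.stop_lr + (self.start_lr - self.stop_lr) * cosine)
```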
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| deepmd/dpmodel/utils/learning_rate.py | Implements the core LearningRateCosine class with cosine annealing formula |
| deepmd/utils/argcheck.py | Adds configuration arguments for cosine scheduler with start_lr and stop_lr parameters |
| deepmd/pt/utils/learning_rate.py | Exports LearningRateCosine for PyTorch backend |
| deepmd/pt/train/training.py | Refactors get_lr function to handle both exponential and cosine schedulers dynamically |
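To illustrate the `get_lr` dispatch mentioned in the last table row, a factory of this shape is plausible; the import paths and exact signatures below are assumptions, not the PR's verbatim code:

```python
# Illustrative only: import locations and signatures are assumed, not verified against the PR.
from deepmd.pt.utils.learning_rate import (
    LearningRateCosine,
    LearningRateExp,
)


def get_lr(lr_params: dict):
    """Build a learning-rate scheduler from the training config's learning_rate section."""
    lr_type = lr_params.get("type", "exp")
    if lr_type == "exp":
        return LearningRateExp(**lr_params)
    if lr_type == "cosine":
        return LearningRateCosine(**lr_params)
    raise ValueError(f"Unknown learning rate type: {lr_type}")
```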
📝 Walkthrough

Added a cosine-annealing learning-rate scheduler (LearningRateCosine), integrated it into the training LR factory and public PT API, extended argument validation to accept a cosine variant, and added unit tests to verify cosine schedule behavior.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant Trainer as Trainer (training.py)
    participant LRFactory as get_lr()
    participant LRClass as LearningRateCosine / LearningRateExp
    participant Optim as Optimizer
    User->>Trainer: start training
    Trainer->>LRFactory: build lr from lr_params (type, start_lr, stop_steps, stop_lr/stop_lr_factor)
    alt type == "cosine"
        LRFactory->>LRClass: instantiate LearningRateCosine(config)
    else type == "exp"
        LRFactory->>LRClass: instantiate LearningRateExp(config)
    end
    loop per training step
        Trainer->>LRClass: lr = value(step)
        LRClass-->>Trainer: lr (np.float64)
        Trainer->>Optim: set lr and step optimizer
        Optim-->>Trainer: step result
    end
    Note right of LRClass: Cosine annealing computed as\nstop_lr + (start_lr - stop_lr) * 0.5*(1+cos(pi*clamped_step/stop_steps))
```
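The per-step interaction in the diagram can be approximated in plain PyTorch as in the toy loop below; the model, data, and hyperparameters are placeholders, and only the learning-rate handling mirrors the diagram, not the trainer's actual code:

```python
import torch

# Toy stand-ins for the real model and data.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-3)
lr_schedule = LearningRateCosine(start_lr=1.0e-3, stop_lr=1.0e-8, stop_steps=1000)  # sketch class from above

for step in range(1000):
    # Query the scheduler and push the new learning rate into the optimizer.
    cur_lr = float(lr_schedule.value(step))
    for group in optimizer.param_groups:
        group["lr"] = cur_lr

    optimizer.zero_grad()
    x = torch.randn(8, 4)
    loss = (model(x) ** 2).mean()
    loss.backward()
    optimizer.step()
```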
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 0
🧹 Nitpick comments (2)
source/tests/pt/test_lr.py (1)
106-120: Consider expanding test coverage to match the comprehensiveness of existing tests.

While the basic curve test validates key behaviors (start, end, plateau, mid-point), it tests only a single parameter configuration with hardcoded values. The existing `TestLearningRate` class (lines 18-104) demonstrates more comprehensive testing with:
- Multiple parameter combinations using `np.arange`
- Edge case validation
- Consistency checks across different configurations

Consider adding:
- Tests with varied `start_lr`, `stop_lr`, and `stop_steps` combinations
- Edge cases: `stop_steps=1`, very large `stop_steps`, `step` exceeding `stop_steps`
- Verification that the curve is monotonically decreasing (for `start_lr > stop_lr`)
- More intermediate points to verify the smoothness of the cosine curve
💡 Example: Enhanced test with multiple configurations
```python
def test_multiple_configurations(self) -> None:
    """Test cosine annealing with various parameter combinations."""
    start_lrs = [1.0, 0.01, 0.001]
    stop_lrs = [0.1, 0.0001, 1e-8]
    stop_steps_list = [10, 100, 1000]
    for start_lr in start_lrs:
        for stop_lr in stop_lrs:
            if stop_lr >= start_lr:
                continue
            for stop_steps in stop_steps_list:
                lr = LearningRateCosine(start_lr, stop_lr, stop_steps)
                # Verify boundary conditions
                self.assertTrue(np.allclose(lr.value(0), start_lr))
                self.assertTrue(np.allclose(lr.value(stop_steps), stop_lr))
                # Verify monotonic decrease
                vals = [lr.value(i) for i in range(stop_steps + 1)]
                self.assertTrue(all(vals[i] >= vals[i + 1] for i in range(len(vals) - 1)))
```

deepmd/dpmodel/utils/learning_rate.py (1)
60-88: Consider adding input validation for robustness.

While the current implementation handles the critical case of `stop_steps` (clamping to 1), consider adding validation for other edge cases to improve robustness:
- Negative `step` values (currently would produce unexpected results)
- `start_lr` or `stop_lr` being non-positive (if that's invalid for your use case)
- `start_lr < stop_lr` (cosine would increase rather than decrease)

This is not critical if the calling code guarantees valid inputs, but defensive validation can prevent subtle bugs.
💡 Example: Optional input validation
```diff
 def __init__(
     self,
     start_lr: float,
     stop_lr: float,
     stop_steps: int,
     **kwargs: Any,
 ) -> None:
     """
     Construct a cosine-annealed learning rate.

     Parameters
     ----------
     start_lr
         The learning rate at the start of the training.
     stop_lr
         The desired learning rate at the end of the training.
     stop_steps
         The total training steps for learning rate scheduler.
     """
+    if start_lr <= 0 or stop_lr <= 0:
+        raise ValueError("Learning rates must be positive")
+    if stop_steps <= 0:
+        raise ValueError("stop_steps must be positive")
     self.start_lr = start_lr
     self.stop_lr = stop_lr
     self.stop_steps = max(1, stop_steps)
```

For the `value` method:

```diff
 def value(self, step: int) -> np.float64:
     """Get the learning rate at the given step."""
+    if step < 0:
+        step = 0
     clamped_step = min(step, self.stop_steps)
     cosine = 0.5 * (1.0 + np.cos(np.pi * clamped_step / self.stop_steps))
     return np.float64(self.stop_lr + (self.start_lr - self.stop_lr) * cosine)
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/dpmodel/utils/learning_rate.py
- source/tests/pt/test_lr.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.
Applied to files:
deepmd/dpmodel/utils/learning_rate.py
🧬 Code graph analysis (1)
source/tests/pt/test_lr.py (1)
deepmd/dpmodel/utils/learning_rate.py (3)
- LearningRateCosine (60-88)
- value (52-57)
- value (84-88)
🪛 Ruff (0.14.10)
deepmd/dpmodel/utils/learning_rate.py
66-66: Unused method argument: kwargs
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (24)
- GitHub Check: Test Python (6, 3.10)
- GitHub Check: Test Python (5, 3.13)
- GitHub Check: Test Python (1, 3.10)
- GitHub Check: Test Python (3, 3.13)
- GitHub Check: Test Python (4, 3.10)
- GitHub Check: Test Python (5, 3.10)
- GitHub Check: Test Python (4, 3.13)
- GitHub Check: Test Python (1, 3.13)
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Test C++ (true, true, true, false)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
🔇 Additional comments (3)
source/tests/pt/test_lr.py (1)
10-10: LGTM! The import is correctly added and necessary for the new test class.
deepmd/dpmodel/utils/learning_rate.py (2)
66-66: The unused `kwargs` parameter is acceptable for API consistency.

While Ruff flags this as unused, accepting `**kwargs` maintains API consistency with `LearningRateExp.__init__` (line 17) and provides forward compatibility for potential extensions without breaking existing code.
84-88: LGTM! The cosine annealing implementation is mathematically correct.

The formula correctly implements cosine annealing:
- At `step=0`: returns `start_lr`
- At `step=stop_steps`: returns `stop_lr`
- Smooth cosine interpolation in between
- Clamping prevents extrapolation beyond `stop_steps` (plateau behavior)
- Return type `np.float64` follows the learning about NumPy scalar types
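As a quick numeric spot-check of these properties (arbitrary values, standalone re-implementation of the same formula):

```python
import numpy as np

start_lr, stop_lr, stop_steps = 1.0e-3, 1.0e-8, 1000


def lr(step: int) -> float:
    # Same cosine-annealing formula, with the step clamped to stop_steps.
    s = min(step, stop_steps)
    return stop_lr + (start_lr - stop_lr) * 0.5 * (1.0 + np.cos(np.pi * s / stop_steps))


print(lr(0))     # start_lr: 1.0e-3
print(lr(500))   # midpoint: cos(pi/2) = 0, so (start_lr + stop_lr) / 2 ≈ 5.0e-4
print(lr(1000))  # stop_lr: 1.0e-8
print(lr(2000))  # beyond stop_steps: clamped, still stop_lr
```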
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/utils/argcheck.py (1)
2512-2544: Document the precedence behavior when both `stop_lr` and `stop_lr_factor` are provided.

The documentation explains behavior when `stop_lr` is omitted, but doesn't clarify what happens when both `stop_lr` and `stop_lr_factor` are provided. Based on the implementation in `deepmd/dpmodel/utils/learning_rate.py` (lines 87-94), `stop_lr` takes precedence. This should be documented to avoid confusion.

📝 Suggested documentation improvement
```diff
 doc_stop_lr = "The desired learning rate at the end of the training."
 doc_stop_lr_factor = (
     "The factor to scale the learning rate at the end of the training. "
     "The actual stop_lr is calculated as `start_lr * stop_lr_factor`. "
-    "If `stop_lr` is not provided, this option will be used."
+    "If `stop_lr` is provided, it takes precedence over this option."
 )
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/dpmodel/utils/learning_rate.py
- deepmd/utils/argcheck.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2024-10-15T22:22:24.889Z
Learnt from: njzjz
Repo: deepmodeling/deepmd-kit PR: 4219
File: deepmd/utils/learning_rate.py:48-53
Timestamp: 2024-10-15T22:22:24.889Z
Learning: Methods in `deepmd/utils/learning_rate.py` that return NumPy scalar types should have return type annotations using the corresponding NumPy types, such as `np.float64`.
Applied to files:
deepmd/dpmodel/utils/learning_rate.py
🧬 Code graph analysis (1)
deepmd/dpmodel/utils/learning_rate.py (2)
deepmd/tf/utils/learning_rate.py (1)
- start_lr (96-98)

deepmd/pt/train/training.py (1)
- step (752-1124)
🪛 Ruff (0.14.10)
deepmd/dpmodel/utils/learning_rate.py
67-67: Unused method argument: kwargs
(ARG002)
92-94: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (16)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Test C++ (true, false, false, true)
- GitHub Check: Test C++ (false, false, false, true)
- GitHub Check: Test C++ (false, true, true, false)
- GitHub Check: Test C++ (true, true, true, false)
🔇 Additional comments (3)
deepmd/utils/argcheck.py (1)
2546-2563: LGTM! Clean integration of the cosine learning rate variant.

The cosine annealing option is correctly integrated into the existing learning rate configuration system, following the same pattern as the exponential variant. The PT-only restriction is properly documented.
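For context, a user would presumably select this scheduler through the `learning_rate` section of a PT training input. The snippet below shows that section as a Python dict; key names other than `type`, `start_lr`, `stop_lr`, and `stop_lr_factor` are assumptions and not verified against the final schema:

```python
# Hypothetical "learning_rate" section of a PyTorch-backend training input.
lr_params = {
    "type": "cosine",   # selects the cosine-annealing scheduler (PT-only per this PR)
    "start_lr": 1.0e-3,
    "stop_lr": 1.0e-8,  # alternatively, stop_lr_factor may be given instead of stop_lr
}
```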
deepmd/dpmodel/utils/learning_rate.py (2)
61-95: LGTM! The constructor validation logic is sound.

The parameter validation correctly ensures either `stop_lr` or `stop_lr_factor` is provided, with `stop_lr` taking precedence when both are specified. The clamping of `stop_steps` to a minimum of 1 prevents division by zero in the `value()` method.

The unused `**kwargs` parameter (flagged by static analysis) is acceptable here; it maintains consistency with `LearningRateExp.__init__` and provides forward compatibility.
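A minimal sketch of the resolution logic this comment describes, written as a hypothetical helper rather than the constructor's literal code:

```python
def resolve_stop_lr(
    start_lr: float,
    stop_lr: float | None = None,
    stop_lr_factor: float | None = None,
) -> float:
    """Explicit stop_lr takes precedence; otherwise derive it from stop_lr_factor."""
    if stop_lr is not None:
        return stop_lr
    if stop_lr_factor is not None:
        return start_lr * stop_lr_factor
    raise ValueError("Either stop_lr or stop_lr_factor must be provided")
```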
97-101: LGTM! The cosine annealing formula is mathematically correct.

The implementation correctly computes cosine-annealed learning rates:
- At step 0: returns `start_lr`
- At `stop_steps`: returns `stop_lr`
- Steps beyond `stop_steps` plateau at `stop_lr` due to clamping

The return type `np.float64` aligns with the existing `LearningRateExp.value()` method and follows established conventions for this module. Based on learnings, methods in this module returning NumPy scalar types should use `np.float64` annotations.
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5133      +/-   ##
==========================================
- Coverage   82.15%   81.89%   -0.26%
==========================================
  Files         709      712       +3
  Lines       72468    74560    +2092
  Branches     3616     3615       -1
==========================================
+ Hits        59535    61063    +1528
- Misses      11769    12334     +565
+ Partials     1164     1163       -1
```

☔ View full report in Codecov by Sentry.
Reopened in #5142 with previous implementation. |
Summary by CodeRabbit
New Features
Chores
Tests