feat(pt/dp): add cosine LR & BaseLR #5142

Conversation
📝 Walkthrough

Introduce a learning-rate abstraction, `BaseLR`, and a cosine learning-rate schedule.

Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor Trainer
    participant Config
    participant LRFactory
    participant LRSchedule
    Trainer->>Config: read lr type & params
    Trainer->>LRFactory: get_lr(lr_params)
    LRFactory-->>LRSchedule: instantiate schedule via BaseLR registry
    Trainer->>LRSchedule: request value(step)
    LRSchedule-->>Trainer: return lr_value
```
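A minimal sketch of the flow the diagram describes, assuming `BaseLR` is importable from `deepmd.dpmodel.utils.learning_rate` as the review excerpts below suggest; the explicit `num_steps`/`warmup_steps` parameters are a simplification of the trainer's use of `self`, so treat this as an illustration rather than the PR's exact code:

```python
from typing import Any

from deepmd.dpmodel.utils.learning_rate import BaseLR  # path as used in this PR


def get_lr(lr_params: dict[str, Any], num_steps: int, warmup_steps: int) -> BaseLR:
    # Inject the effective number of decay steps before construction,
    # mirroring the trainer logic quoted later in this review.
    lr_params["stop_steps"] = num_steps - warmup_steps
    # BaseLR.__new__ reads lr_params["type"] (e.g. "exp" or "cosine")
    # and instantiates the subclass registered under that name.
    return BaseLR(**lr_params)
```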
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/pt/train/training.py (1)
270-279: Well-structured learning rate scheduler dispatch.

The implementation cleanly supports both exponential and cosine learning rate schedules with appropriate error handling for unsupported types. The dispatch logic is clear and the integration is correct.

Optional: Consider extracting the error message

Per the static analysis hint, the error message could be extracted to reduce line length:

```diff
+UNSUPPORTED_LR_TYPE_MSG = "Not supported learning rate type '{}'!"
+
 def get_lr(lr_params: dict[str, Any]) -> LearningRateExp:
     lr_type = lr_params.get("type", "exp")
     lr_params["stop_steps"] = self.num_steps - self.warmup_steps
     if lr_type == "exp":
         lr_schedule = LearningRateExp(**lr_params)
     elif lr_type == "cosine":
         lr_schedule = LearningRateCosine(**lr_params)
     else:
-        raise ValueError(f"Not supported learning rate type '{lr_type}'!")
+        raise ValueError(UNSUPPORTED_LR_TYPE_MSG.format(lr_type))
     return lr_schedule
```

This is a minor style improvement and not critical.
deepmd/dpmodel/utils/learning_rate.py (1)
87-95: Cosine annealing formula is mathematically correct.

The `value(step)` method correctly implements cosine annealing:

- At step=0: returns `start_lr` ✓
- At step=stop_steps: returns `stop_lr` ✓
- Between 0 and stop_steps: smoothly interpolates following a cosine curve ✓
- Beyond stop_steps: maintains `stop_lr` ✓

The formula `start_lr * (lr_min_factor + 0.5 * (1 - lr_min_factor) * (1 + cos(π * step / stop_steps)))` produces the expected smooth annealing behavior.

Optional: Consider adding input validation

For improved robustness, you could add validation in `__init__` to prevent division-by-zero edge cases:

```python
def __init__(
    self,
    start_lr: float,
    stop_lr: float,
    stop_steps: int,
    **kwargs: Any,
) -> None:
    """..."""
    if start_lr <= 0:
        raise ValueError(f"start_lr must be positive, got {start_lr}")
    if stop_steps <= 0:
        raise ValueError(f"stop_steps must be positive, got {stop_steps}")
    self.start_lr = start_lr
    self.lr_min_factor = stop_lr / start_lr
    self.stop_steps = stop_steps
```

However, note that `LearningRateExp` also lacks such validation, so this is a codebase-wide pattern. Adding validation here would be an enhancement but is not critical for this PR.
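Written out, the annealing formula quoted in this comment (with $s$ the step, $S$ = `stop_steps`, and $f$ = `lr_min_factor` = `stop_lr`/`start_lr`) is:

$$
\mathrm{lr}(s) = \mathrm{start\_lr}\left[f + \frac{1-f}{2}\left(1 + \cos\frac{\pi s}{S}\right)\right], \qquad 0 \le s \le S,
$$

so that $\mathrm{lr}(0) = \mathrm{start\_lr}$ and $\mathrm{lr}(S) = f \cdot \mathrm{start\_lr} = \mathrm{stop\_lr}$.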
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- deepmd/dpmodel/utils/learning_rate.py
- deepmd/pt/train/training.py
- deepmd/pt/utils/learning_rate.py
- deepmd/utils/argcheck.py
🧰 Additional context used
🧬 Code graph analysis (3)
deepmd/dpmodel/utils/learning_rate.py (2)
- deepmd/tf/utils/learning_rate.py (1): `start_lr` (96-98)
- deepmd/pt/train/training.py (1): `step` (763-1135)

deepmd/pt/train/training.py (1)
- deepmd/dpmodel/utils/learning_rate.py (2): `LearningRateCosine` (60-95), `LearningRateExp` (9-57)

deepmd/pt/utils/learning_rate.py (1)
- deepmd/dpmodel/utils/learning_rate.py (2): `LearningRateCosine` (60-95), `LearningRateExp` (9-57)
🪛 Ruff (0.14.10)
deepmd/dpmodel/utils/learning_rate.py
66-66: Unused method argument: kwargs
(ARG002)
deepmd/pt/train/training.py
278-278: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40 checks skipped: C++ build/test jobs, the Python 3.10/3.13 test matrix, wheel builds, CodeQL analyses, and the C library build.)
🔇 Additional comments (5)
deepmd/pt/utils/learning_rate.py (1)
1-10: LGTM! Clean re-export of the cosine learning rate scheduler.

The import and export of `LearningRateCosine` follows the existing pattern for `LearningRateExp`, properly exposing the cosine annealing scheduler from the dpmodel layer to the PyTorch training interface.

deepmd/pt/train/training.py (1)
66-66: Import added correctly.

The `LearningRateCosine` import is properly placed alongside the existing `LearningRateExp` import.

deepmd/utils/argcheck.py (2)
2512-2527: Cosine learning rate argument definition is well-structured.

The `learning_rate_cosine()` function properly defines the configuration arguments for cosine annealing. The simpler parameter set (only `start_lr` and `stop_lr`, without `decay_steps` and `decay_rate`) is appropriate for cosine annealing, which follows a predefined smooth curve over the total training steps.

Note: The default `stop_lr=1e-5` for cosine (vs. `1e-8` for exponential) represents a less aggressive decay, which is a reasonable design choice for the smoother cosine schedule.
2529-2541: Variant integration is correct.

The cosine learning rate option is properly exposed in the variant type arguments, enabling users to configure `type: cosine` in their training configuration.
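For illustration, the `learning_rate` parameters selecting the cosine schedule would presumably look like the following; only the `type`, `start_lr`, and `stop_lr` field names are taken from the review above, and the values are illustrative:

```python
lr_params = {
    "type": "cosine",   # selects the cosine variant in the registry
    "start_lr": 1e-3,   # illustrative value, not a documented default
    "stop_lr": 1e-5,    # matches the default mentioned above
}
```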
deepmd/dpmodel/utils/learning_rate.py (1)

60-86: Cosine annealing learning rate class is well-implemented.

The `LearningRateCosine` class correctly implements the cosine annealing schedule. The initialization properly stores the required parameters and pre-computes `lr_min_factor` for efficiency. The interface is consistent with `LearningRateExp`, including the `start_lr` attribute (required for warmup in training) and the `**kwargs` parameter for API extensibility.
Actionable comments posted: 0
🧹 Nitpick comments (1)
source/tests/pt/test_lr.py (1)
106-120: Consider expanding test coverage with parameterized tests.

The test correctly validates the basic cosine annealing behavior. To align with the existing `TestLearningRate` pattern and improve robustness, consider adding (see also the sketch after this list):

- Parameterized tests with multiple combinations of `start_lr`, `stop_lr`, and `stop_steps` (similar to lines 19-31)
- Edge cases: `start_lr == stop_lr`, `stop_steps == 1`, very small `stop_lr` values
- Additional intermediate points to verify the cosine curve shape more thoroughly

💡 Example expansion

```python
def test_basic_curve(self) -> None:
    start_lr = 1.0
    stop_lr = 0.1
    stop_steps = 10
    lr = LearningRateCosine(start_lr, stop_lr, stop_steps)
    self.assertTrue(np.allclose(lr.value(0), start_lr))
    self.assertTrue(np.allclose(lr.value(stop_steps), stop_lr))
    self.assertTrue(np.allclose(lr.value(stop_steps + 5), stop_lr))
    mid_step = stop_steps // 2
    expected_mid = stop_lr + (start_lr - stop_lr) * 0.5
    self.assertTrue(np.allclose(lr.value(mid_step), expected_mid))


def test_edge_cases(self) -> None:
    # Test equal start and stop
    lr_flat = LearningRateCosine(0.001, 0.001, 100)
    self.assertTrue(np.allclose(lr_flat.value(0), 0.001))
    self.assertTrue(np.allclose(lr_flat.value(50), 0.001))
    # Test single step
    lr_single = LearningRateCosine(1.0, 0.1, 1)
    self.assertTrue(np.allclose(lr_single.value(0), 1.0))
    self.assertTrue(np.allclose(lr_single.value(1), 0.1))
```
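One way the parameterized variant from the first bullet could look, sketched with `unittest`'s `subTest`; the sweep values are illustrative, not from the PR:

```python
def test_cosine_parameterized(self) -> None:
    # Sweep several (start_lr, stop_lr, stop_steps) combinations,
    # checking both endpoints of the cosine schedule for each.
    for start_lr, stop_lr, stop_steps in [
        (1e-3, 1e-5, 100),
        (5e-4, 5e-6, 1000),
        (1.0, 0.1, 10),
    ]:
        with self.subTest(start_lr=start_lr, stop_lr=stop_lr, stop_steps=stop_steps):
            lr = LearningRateCosine(start_lr, stop_lr, stop_steps)
            self.assertTrue(np.allclose(lr.value(0), start_lr))
            self.assertTrue(np.allclose(lr.value(stop_steps), stop_lr))
```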
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/utils/argcheck.py
- source/tests/pt/test_lr.py
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/utils/argcheck.py
🧰 Additional context used
🧬 Code graph analysis (1)
source/tests/pt/test_lr.py (1)
- deepmd/dpmodel/utils/learning_rate.py (1): `LearningRateCosine` (60-95)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (41 checks skipped: C++ build/test jobs, the Python 3.10/3.13 test matrix, wheel builds, CodeQL analyses, and the C library build.)
🔇 Additional comments (1)
source/tests/pt/test_lr.py (1)
10-10: LGTM!

The import is correctly added and follows the existing pattern.
Actionable comments posted: 0
🧹 Nitpick comments (1)
deepmd/dpmodel/utils/learning_rate.py (1)
90-124: Cosine annealing implementation looks correct.

The formula correctly implements cosine annealing: starting at `start_lr`, reaching `stop_lr` at `stop_steps`, with smooth cosine interpolation.

One minor observation: Line 118 could be simplified to `return self.stop_lr`, since `self.start_lr * self.lr_min_factor` equals `stop_lr` by definition. However, keeping it as-is maintains consistency with the formula in the main branch.

Optional simplification

```diff
 def value(self, step: int) -> np.float64:
     if step >= self.stop_steps:
-        return self.start_lr * self.lr_min_factor
+        return self.stop_lr
     return self.start_lr * (
         self.lr_min_factor
         + 0.5
         * (1 - self.lr_min_factor)
         * (1 + np.cos(np.pi * (step / self.stop_steps)))
     )
```
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/dpmodel/utils/learning_rate.py
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/dpmodel/utils/learning_rate.py (2)
- deepmd/tf/utils/learning_rate.py (1): `start_lr` (96-98)
- deepmd/pt/train/training.py (1): `step` (763-1135)
🪛 Ruff (0.14.10)
deepmd/dpmodel/utils/learning_rate.py
15-15: Unused method argument: kwargs
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40 checks skipped: C++ build/test jobs, the Python 3.10/3.13 test matrix, wheel builds, CodeQL analyses, and the C library build.)
🔇 Additional comments (3)
deepmd/dpmodel/utils/learning_rate.py (3)
2-5: LGTM!

Standard library imports for the ABC pattern are correctly added.

13-36: LGTM!

The abstract base class provides a clean contract for learning rate schedules. The `**kwargs` parameter (flagged by Ruff) is intentionally included for forward compatibility, allowing subclasses to accept and pass additional parameters through the hierarchy.

39-87: LGTM!

The refactoring correctly delegates common attributes to the base class while preserving the existing exponential decay logic. The `min_lr` now appropriately references `self.stop_lr` from the base class.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In @deepmd/dpmodel/utils/learning_rate.py:
- Around line 104-139: LearningRateCosine.__init__ currently computes
self.lr_min_factor = stop_lr / start_lr which will raise ZeroDivisionError if
start_lr is zero; add an explicit guard at the start of __init__ (e.g., if
start_lr == 0.0: raise ValueError("start_lr must be non-zero") or check
abs(start_lr) < eps) to raise a clear ValueError with a helpful message instead
of letting a ZeroDivisionError occur, then proceed to compute self.lr_min_factor
as before.
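A minimal sketch of the guard the prompt asks for, assuming the `__init__` shape shown earlier in this review:

```python
def __init__(
    self,
    start_lr: float,
    stop_lr: float,
    stop_steps: int,
    **kwargs: Any,
) -> None:
    if start_lr == 0.0:
        # Fail fast with a clear message rather than letting the
        # division below raise ZeroDivisionError.
        raise ValueError("start_lr must be non-zero")
    self.start_lr = start_lr
    self.lr_min_factor = stop_lr / start_lr
    self.stop_steps = stop_steps
```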
🧹 Nitpick comments (2)
deepmd/pt/train/training.py (1)
269-272: Consider renaming `self.lr_exp` to `self.lr_schedule` throughout the class.

The local variable is correctly named `lr_schedule` (line 271), but the instance variable `self.lr_exp` is used throughout the class (e.g., lines 434, 436, 438, 682, 690, 766, 769, 1165). Since this variable now holds any `BaseLR` subclass (not just exponential), consider renaming it for clarity.

Example locations using the old name

```python
# Line 434
self.lr_exp = {}

# Line 682
return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr

# Line 766
if isinstance(self.lr_exp, dict):
    _lr = self.lr_exp[task_key]
else:
    _lr = self.lr_exp
```

deepmd/dpmodel/utils/learning_rate.py (1)
21-49: Consider removing the unused `args` parameter from the `__new__` method.

The `__new__` method accepts `*args` but doesn't use it; only `**kwargs` is accessed for type resolution. While this might be intentional for a standard signature, the static analysis warning is valid.

Proposed adjustment

```diff
 class BaseLR(ABC, PluginVariant, make_plugin_registry("lr")):
-    def __new__(cls: type, *args: Any, **kwargs: Any) -> Any:
+    def __new__(cls: type, **kwargs: Any) -> Any:
         if cls is BaseLR:
             cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
         return super().__new__(cls)
```

Note: If subclasses need positional arguments, keep `*args` for compatibility. However, the current plugin pattern only uses keyword arguments.
📜 Review details
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- deepmd/dpmodel/utils/learning_rate.py
- deepmd/pt/train/training.py
- deepmd/pt/utils/learning_rate.py
🧰 Additional context used
🧬 Code graph analysis (2)
deepmd/pt/utils/learning_rate.py (1)
- deepmd/dpmodel/utils/learning_rate.py (2): `BaseLR` (21-49), `LearningRateCosine` (105-139)
deepmd/pt/train/training.py (2)
- deepmd/dpmodel/utils/learning_rate.py (1): `BaseLR` (21-49)
- deepmd/pd/train/training.py (1): `get_lr` (241-247)
🪛 Ruff (0.14.10)
deepmd/dpmodel/utils/learning_rate.py
22-22: Unused static method argument: args
(ARG004)
28-28: Unused method argument: kwargs
(ARG002)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (40 checks skipped: C++ build/test jobs, the Python 3.10/3.13 test matrix, wheel builds, CodeQL analyses, and the C library build.)
🔇 Additional comments (7)
deepmd/pt/utils/learning_rate.py (2)
2-6: LGTM! Clean re-export of new learning rate classes.

The imports correctly bring in the new `BaseLR` base class and `LearningRateCosine` scheduler from the dpmodel layer, enabling the PyTorch training module to use the new cosine annealing schedule.

8-12: LGTM! Public API correctly updated.

The `__all__` list properly exposes the new learning rate classes alongside the existing `LearningRateExp`, maintaining backward compatibility while enabling the new cosine scheduler feature.

deepmd/pt/train/training.py (2)

65-67: LGTM! Import updated to support the new learning rate abstraction.

The change from importing `LearningRateExp` to `BaseLR` correctly enables the generic learning rate scheduling mechanism, allowing both exponential and cosine schedulers to be used.

269-272: The plugin registry pattern correctly requires the "type" key in `lr_params`.

Verification confirms that `BaseLR.__new__` uses `j_get_type()` to extract the "type" key from `lr_params` and resolve the appropriate learning rate subclass. If the "type" key is missing, `j_get_type()` raises a clear KeyError. Configuration files must provide the "type" key (e.g., "exp", "cosine") in the `learning_rate` section, as shown in example configs. The refactoring correctly leverages the plugin registry pattern and is properly designed.

deepmd/dpmodel/utils/learning_rate.py (3)
2-18: LGTM! Necessary imports for the plugin-based learning rate framework.

The imports correctly bring in the abstract base class utilities and the plugin infrastructure needed for dynamic learning rate scheduler registration and resolution.

52-101: LGTM! Clean refactoring of exponential decay as a BaseLR subclass.

The `LearningRateExp` class is properly refactored to:

- Inherit from `BaseLR` and call `super().__init__()` with the required parameters
- Register with `@BaseLR.register("exp")` for plugin resolution
- Maintain backward compatibility with the existing exponential decay logic

131-139: LGTM! Cosine annealing formula is mathematically correct.

The implementation correctly applies cosine annealing:

- At step 0: returns `start_lr` (cos(0) = 1)
- At stop_steps: returns `stop_lr` (cos(π) = -1)
- Beyond stop_steps: maintains `stop_lr`
- Smooth cosine decay in between
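Those endpoint claims are easy to sanity-check with a standalone transcription of the reviewed formula (verification code only, not the library implementation):

```python
import numpy as np


def cosine_lr(step: int, start_lr: float, stop_lr: float, stop_steps: int) -> float:
    # Transcription of the reviewed cosine annealing formula.
    f = stop_lr / start_lr
    if step >= stop_steps:
        return stop_lr
    return start_lr * (f + 0.5 * (1 - f) * (1 + np.cos(np.pi * step / stop_steps)))


assert np.isclose(cosine_lr(0, 1e-3, 1e-5, 100), 1e-3)    # cos(0) = 1
assert np.isclose(cosine_lr(100, 1e-3, 1e-5, 100), 1e-5)  # cos(pi) = -1
assert np.isclose(cosine_lr(500, 1e-3, 1e-5, 100), 1e-5)  # clamped beyond stop_steps
```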
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff            @@
##           master    #5142     +/-   ##
=========================================
  Coverage   81.94%   81.95%
=========================================
  Files         712      712
  Lines       72887    72918      +31
  Branches     3616     3616
=========================================
+ Hits        59725    59757      +32
  Misses      11998    11998
+ Partials     1164     1163       -1
```

☔ View full report in Codecov by Sentry.
OutisLi left a comment
I believe adding an optional argument named stop_lr_ratio would be nice. Then we could simply set it to 0.01 and only change the start_lr afterwards, just like in many papers.
@OutisLi Sure, it's a good idea; maybe you can add it in the next PR, along with the warmup procedure.
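For context, the proposed option would presumably derive the floor from the peak, roughly as follows (hypothetical, not implemented in this PR):

```python
# stop_lr_ratio is the suggested, not-yet-implemented option.
start_lr = 1e-3
stop_lr_ratio = 0.01
stop_lr = start_lr * stop_lr_ratio  # changing start_lr rescales the whole schedule
```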
Summary by CodeRabbit

- New Features
- Configuration
- Tests
- Refactor