feat(pt): Add support for SiLU activation function in gradient calculations #5055
Conversation
- Introduced the SiLU (Sigmoid Linear Unit) activation function with corresponding gradient and second derivative calculations.
- Updated the activation function mapping to include SiLU, enhancing the flexibility of activation functions available in the DPTabulate class.
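For reference, the activation and the two derivatives this PR wires in can be sketched directly in PyTorch; the function names below are illustrative and not the actual names used in deepmd/pt/utils/tabulate.py:

```python
import torch


def silu(x: torch.Tensor) -> torch.Tensor:
    # SiLU / swish: f(x) = x * sigmoid(x)
    return x * torch.sigmoid(x)


def silu_grad(x: torch.Tensor) -> torch.Tensor:
    # f'(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
    sig = torch.sigmoid(x)
    return sig + x * sig * (1.0 - sig)


def silu_grad_grad(x: torch.Tensor) -> torch.Tensor:
    # f''(x) = 2 * d_sig + x * d_sig * (1 - 2 * sigmoid(x)), with d_sig = sigmoid'(x)
    sig = torch.sigmoid(x)
    d_sig = sig * (1.0 - sig)
    return 2.0 * d_sig + x * d_sig * (1.0 - 2.0 * sig)
```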
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Adds SiLU (swish) activation support as functype 7 across the Python tabulation utilities, the C++ unaggregated gradient code, and the PyTorch tests; implements SiLU first- and second-derivative branches, updates the activation mapping, and parameterizes the tests to exercise activations 1–7.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant T as Tests
    participant Py as Python tabulate
    participant C as C++ unaggregated_grad
    Note over T,Py: Iterate functype 1..7
    T->>Py: activation = get_activation_function(functype)
    T->>Py: y = activation(x)
    T->>Py: request derivatives (unaggregated_dy_dx[_s], unaggregated_dy2_dx[_s]) with functype
    Py->>C: dispatch grad/grad_grad using functype
    alt functype == 7
        C-->>Py: compute SiLU grad & grad_grad (sigmoid-based formulas)
    else
        C-->>Py: compute existing activation derivatives
    end
    Py-->>T: return derivative tensors for assertions
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Suggested reviewers
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/pt/utils/tabulate.py (2)
48-49: Align activation docs and `ActivationFn` string with new SiLU support

You map `"silu": 7` correctly into `activation_map`, but the class docstring still claims only `{"tanh","gelu"}` are supported, and it doesn't mention SiLU or the other already-supported activations. Also, please double-check that `ActivationFn` elsewhere in the codebase accepts `"silu"` as the activation string (same spelling/casing) so this mapping is reachable.

Consider updating the docstring to defer to `ActivationFn` and include SiLU in the examples, e.g.:

```diff
-    activation_function
-        The activation function in the embedding net. Supported options are {"tanh","gelu"} in common.ActivationFn.
+    activation_function
+        The activation function in the embedding net. See :class:`ActivationFn`
+        for supported options (e.g. "tanh", "gelu", "relu", "silu").
```

Also applies to: 78-88
445-509: Add targeted tests for SiLU gradients and Hessians

Now that `functype == 7` is wired into `grad`/`grad_grad`, it would be good to add tests that:

- Compare `grad(xbar, y, 7)` and `grad_grad(xbar, y, 7)` against PyTorch autograd on `torch.nn.SiLU` over a range of `xbar` values (see the sketch after this list).
- Exercise both CPU and GPU (if supported in your CI).

This will guard against future regressions in the hand-coded formulas.
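A minimal sketch of such a check, assuming `grad` and `grad_grad` can be imported from `deepmd.pt.utils.tabulate` and accept the `(xbar, y, functype)` arguments quoted above (both the import path and the exact signature are assumptions to verify):

```python
import torch

# Assumed import; verify the module path and whether these helpers are public.
from deepmd.pt.utils.tabulate import grad, grad_grad


def check_silu_derivatives(n: int = 101) -> None:
    xbar = torch.linspace(-5.0, 5.0, n, dtype=torch.float64, requires_grad=True)
    y = torch.nn.functional.silu(xbar)

    # Reference first and second derivatives via autograd.
    (dy_dx,) = torch.autograd.grad(y.sum(), xbar, create_graph=True)
    (d2y_dx2,) = torch.autograd.grad(dy_dx.sum(), xbar)

    # Hand-coded formulas for functype == 7 should match the autograd results.
    torch.testing.assert_close(grad(xbar, y, 7), dy_dx)
    torch.testing.assert_close(grad_grad(xbar, y, 7), d2y_dx2)
```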
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
deepmd/pt/utils/tabulate.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
deepmd/pt/utils/tabulate.py (1)
deepmd/pt/utils/utils.py (1)
sigmoid (154-155)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (31)
- GitHub Check: CodeQL analysis (python)
- GitHub Check: Agent
- GitHub Check: Test C++ (false)
- GitHub Check: Test C++ (true)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
🔇 Additional comments (2)
deepmd/pt/utils/tabulate.py (2)
472-476: SiLU first derivative implementation looks correct

The `functype == 7` branch implements

```python
sig = torch.sigmoid(xbar)
sig + xbar * sig * (1 - sig)
```

which matches \( \frac{d}{dx}\bigl[x \cdot \sigma(x)\bigr] = \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr) \). Using a single `torch.sigmoid` call and reusing `sig` is efficient and consistent with the other branches.
504-508: SiLU second derivative implementation is mathematically consistent

Here you compute:

```python
sig = torch.sigmoid(xbar)
d_sig = sig * (1 - sig)
2 * d_sig + xbar * d_sig * (1 - 2 * sig)
```

which matches the analytically derived \( f''(x) = 2\,d_\sigma + x\,d_\sigma\bigl(1 - 2\sigma(x)\bigr) \) for \( f(x) = x\,\sigma(x) \), where \( d_\sigma = \sigma'(x) = \sigma(x)\bigl(1 - \sigma(x)\bigr) \). This integrates cleanly with the existing `grad_grad` interface.
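Both quoted expressions follow from the product rule applied to \( f(x) = x\,\sigma(x) \), using \( \sigma'(x) = \sigma(x)(1-\sigma(x)) \) (i.e. the `d_sig` above):

```latex
\begin{aligned}
f(x)   &= x\,\sigma(x) \\
f'(x)  &= \sigma(x) + x\,\sigma'(x) \\
f''(x) &= 2\,\sigma'(x) + x\,\sigma''(x),
          \qquad \sigma''(x) = \sigma'(x)\bigl(1 - 2\sigma(x)\bigr) \\
       &= 2\,\sigma'(x) + x\,\sigma'(x)\bigl(1 - 2\sigma(x)\bigr)
\end{aligned}
```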
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
Pull Request Overview
Copilot reviewed 1 out of 1 changed files in this pull request and generated 1 comment.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##            devel    #5055      +/-   ##
==========================================
+ Coverage   84.18%   84.25%   +0.07%
==========================================
  Files         709      709
  Lines       70220    70234      +14
  Branches     3619     3620       +1
==========================================
+ Hits        59116    59177      +61
+ Misses       9936     9889      -47
  Partials     1168     1168
```

☔ View full report in Codecov by Sentry.
- Updated the DPTabulate class documentation to include additional activation functions ("relu", "silu") in the embedding net.
- Added a new activation function (case 7: SiLU) and its gradient calculations in the unaggregated gradient functions.
- Implemented comprehensive tests for all activation functions, ensuring correct behavior across various scenarios in the test suite.
Pull Request Overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
Actionable comments posted: 0
🧹 Nitpick comments (2)
deepmd/pt/utils/tabulate.py (1)
49-50: Docstring now lists supported activations but parameter name is slightly inconsistent

The updated docstring correctly points to `ActivationFn` and lists `"silu"` among supported options. For clarity, consider aligning the documented parameter name (`activation_function`) with the actual `__init__` argument name (`activation_fn`).

source/tests/pt/test_tabulate.py (1)

22-44: Activation helper and name map cover all functypes, including SiLU

`get_activation_function` and `ACTIVATION_NAMES` cleanly cover functypes 1–7, with the SiLU case implemented as `x / (1 + exp(-x))`, which matches the intended `x * sigmoid(x)` behavior used elsewhere. This keeps the tests aligned with the activation map in `DPTabulate` and the new SiLU gradients. The `ValueError` on unknown `functype` is also fine here; the Ruff TRY003 hint is purely stylistic and can be ignored or suppressed if it's noisy.

Also applies to: 46-54
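A rough sketch of what such a helper pair might look like; only `"silu": 7` is confirmed by this PR, so the other functype assignments and the exact helper bodies below are assumptions, not the repository code:

```python
import torch

# Assumed mapping: only functype 7 ("silu") is confirmed by this PR.
ACTIVATION_NAMES = {
    1: "tanh",
    2: "gelu",
    3: "relu",
    4: "relu6",
    5: "softplus",
    6: "sigmoid",
    7: "silu",
}


def get_activation_function(functype: int):
    """Return a callable activation for an integer functype (1-7)."""
    if functype == 7:
        # SiLU written explicitly as x * sigmoid(x) == x / (1 + exp(-x)).
        return lambda x: x * torch.sigmoid(x)
    if functype in ACTIVATION_NAMES:
        return getattr(torch.nn.functional, ACTIVATION_NAMES[functype])
    raise ValueError(f"Unknown functype: {functype}")
```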
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
deepmd/pt/utils/tabulate.py (4 hunks)
source/op/tf/unaggregated_grad.cc (2 hunks)
source/tests/pt/test_tabulate.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
source/tests/pt/test_tabulate.py (1)
deepmd/pt/utils/tabulate.py (4)
unaggregated_dy_dx_s (515-534)
unaggregated_dy2_dx_s (537-563)
unaggregated_dy_dx (566-600)
unaggregated_dy2_dx (603-645)
🪛 Ruff (0.14.4)
source/tests/pt/test_tabulate.py
43-43: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (31)
- GitHub Check: CodeQL analysis (python)
- GitHub Check: Agent
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Test C++ (false)
- GitHub Check: Analyze (python)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Test C++ (true)
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
🔇 Additional comments (3)
source/op/tf/unaggregated_grad.cc (1)
78-81: SiLU gradient and second derivative implementation look correct and consistent

The new `case 7` branches in `grad` and `grad_grad` implement the first and second derivatives of SiLU `f(x) = x * sigmoid(x)` as `σ + x σ (1 - σ)` and `2 σ' + x σ' (1 - 2 σ)`, respectively, which are mathematically correct and consistent with the Python tabulate implementations and the new tests.

Also applies to: 112-116
deepmd/pt/utils/tabulate.py (1)
80-89: SiLU activation mapping and derivatives are coherent across the stack

Adding `"silu": 7` to `activation_map` and the corresponding `functype == 7` branches in `grad` and `grad_grad` provides a consistent SiLU implementation: `grad` returns `σ(x) + x σ(x) (1 - σ(x))` and `grad_grad` returns `2 σ'(x) + x σ'(x) (1 - 2 σ(x))`, matching the analytic derivatives of `f(x) = x * sigmoid(x)` and the C++ `UnaggregatedDyDx*` implementations and tests.

Also applies to: 473-477, 505-509
source/tests/pt/test_tabulate.py (1)
82-89: Parameterized TF/PT parity tests across all activations are well structured

The refactored `test_ops` with `_test_single_activation` systematically exercises all four unaggregated ops (`*_dy_dx_s`, `*_dy2_dx_s`, `*_dy_dx`, `*_dy2_dx`) for functypes 1–7, comparing TF kernels against the PT utilities with clear per-activation subTest labels. Argument ordering and shapes match the C++ op signatures and Python wrappers, so this should give good coverage for the new SiLU path without altering existing behavior.

Also applies to: 90-203
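The subTest structure described above can be illustrated with a self-contained (and heavily simplified) sketch; the class and helper names are hypothetical, and the real test drives the TF kernels and PT utilities rather than the placeholder assertion shown here:

```python
import unittest

import torch

# Illustrative subset of the functype-to-activation mapping; only 7 ("silu")
# is confirmed by this PR.
ACTIVATIONS = {
    1: torch.tanh,
    7: lambda x: x * torch.sigmoid(x),  # SiLU
}


class TestTabulateActivations(unittest.TestCase):
    def test_ops(self) -> None:
        for functype, activation in ACTIVATIONS.items():
            with self.subTest(functype=functype):
                self._test_single_activation(functype, activation)

    def _test_single_activation(self, functype, activation) -> None:
        x = torch.linspace(-3.0, 3.0, 17, dtype=torch.float64)
        y = activation(x)
        # Placeholder assertion; the real test compares the four TF unaggregated
        # ops against the PT utilities for each functype.
        self.assertEqual(y.shape, x.shape)


if __name__ == "__main__":
    unittest.main()
```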
… but wrapped to use num index instead of names
Actionable comments posted: 0
🧹 Nitpick comments (1)
source/tests/pt/test_tabulate.py (1)
35-40: LGTM! Clean helper function.

The `get_activation_function` helper provides a clear interface for looking up activation functions by functype with appropriate error handling.

Optional: The static analysis hint (TRY003) suggests defining the error message within a custom exception class for better exception handling practices, though this is a minor style preference:

```diff
- raise ValueError(f"Unknown functype: {functype}")
+ raise ValueError(f"Unknown functype: {functype}")  # noqa: TRY003
```

or define a custom exception if this pattern is common across the codebase.
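If the project prefers a dedicated exception over a `noqa`, a minimal version could look like the sketch below; `UnknownFunctypeError` is a hypothetical name, not an existing class in the codebase:

```python
class UnknownFunctypeError(ValueError):
    """Raised when an unsupported activation functype is requested."""

    def __init__(self, functype: int) -> None:
        super().__init__(f"Unknown functype: {functype}")


# Hypothetical usage inside get_activation_function:
#     raise UnknownFunctypeError(functype)
```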
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
source/tests/pt/test_tabulate.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
source/tests/pt/test_tabulate.py (2)
deepmd/dpmodel/utils/network.py (1)
get_activation_fn (300-399)
deepmd/pt/utils/tabulate.py (4)
unaggregated_dy_dx_s (515-534)
unaggregated_dy2_dx_s (537-563)
unaggregated_dy_dx (566-600)
unaggregated_dy2_dx (603-645)
🪛 Ruff (0.14.4)
source/tests/pt/test_tabulate.py
38-38: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (29)
- GitHub Check: Build C library (2.14, >=2.5.0,<2.15, libdeepmd_c_cu11.tar.gz)
- GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Build wheels for cp311-win_amd64
- GitHub Check: Build wheels for cp311-macosx_arm64
- GitHub Check: Build wheels for cp310-manylinux_aarch64
- GitHub Check: Build wheels for cp311-macosx_x86_64
- GitHub Check: Build wheels for cp311-manylinux_x86_64
- GitHub Check: Test Python (3, 3.12)
- GitHub Check: Analyze (python)
- GitHub Check: Build C++ (cpu, cpu)
- GitHub Check: Build C++ (clang, clang)
- GitHub Check: Build C++ (cuda120, cuda)
- GitHub Check: Analyze (c-cpp)
- GitHub Check: Build C++ (cuda, cuda)
- GitHub Check: Build C++ (rocm, rocm)
- GitHub Check: Test Python (3, 3.9)
- GitHub Check: Test Python (5, 3.12)
- GitHub Check: Test Python (5, 3.9)
- GitHub Check: Test Python (6, 3.12)
- GitHub Check: Test Python (6, 3.9)
- GitHub Check: Test Python (2, 3.9)
- GitHub Check: Test Python (2, 3.12)
- GitHub Check: Test Python (4, 3.12)
- GitHub Check: Test Python (4, 3.9)
- GitHub Check: Test Python (1, 3.12)
- GitHub Check: Test Python (1, 3.9)
- GitHub Check: Test C++ (true)
- GitHub Check: Test C++ (false)
🔇 Additional comments (4)
source/tests/pt/test_tabulate.py (4)
7-9: LGTM! Good reuse of existing implementation.

The import of `get_activation_fn` correctly leverages the existing activation function implementation from the codebase, addressing the previous review concern about using the existing implementation.
24-32: LGTM! Clear activation mapping.

The `ACTIVATION_NAMES` constant provides a clear, maintainable mapping of functypes to activation names, including the newly added SiLU (functype 7).
67-74: LGTM! Excellent parameterized testing approach.

The refactoring to use `subTest` for each activation function is a best practice that provides:
- Clear test isolation per activation
- Easy identification of failures by activation name and functype
- Maintainable test structure
76-189: Test structure is sound. TensorFlow ops support all functypes 1-7.

Verification confirms that `source/op/tf/unaggregated_grad.cc` implements all required functype cases (1-7) in both gradient computation functions, including the newly added functype 7 (SiLU). The test logic is well-structured with proper device handling and will execute correctly.
Summary by CodeRabbit