Commit d4e9ffc

feat(pt): Add support for SiLU activation function in gradient calculations (#5055)
- Introduced the SiLU (Sigmoid Linear Unit) activation function with corresponding gradient and second-derivative calculations.
- Updated the activation function mapping to include SiLU, enhancing the flexibility of activation functions available in the DPTabulate class.

## Summary by CodeRabbit

* **New Features**
  * Added support for the SiLU (Swish) activation across runtime computations and gradients.
* **Tests**
  * Expanded test coverage to validate all supported activations (tanh, gelu, relu, relu6, softplus, sigmoid, silu) and their first/second derivatives across execution paths.
* **Documentation**
  * Updated activation references and lists to include SiLU among supported options.
1 parent c346332 commit d4e9ffc

File tree: 3 files changed, +103 -23 lines

deepmd/pt/utils/tabulate.py
14 additions, 1 deletion

@@ -46,7 +46,8 @@ class DPTabulate(BaseTabulate):
         The excluded pairs of types which have no interaction with each other.
         For example, `[[0, 1]]` means no interaction between type 0 and type 1.
     activation_function
-        The activation function in the embedding net. Supported options are {"tanh","gelu"} in common.ActivationFn.
+        The activation function in the embedding net. See :class:`ActivationFn`
+        for supported options (e.g. "tanh", "gelu", "relu", "silu").
     """

     def __init__(
@@ -84,6 +85,7 @@ def __init__(
             "relu6": 4,
             "softplus": 5,
             "sigmoid": 6,
+            "silu": 7,
         }

         activation = activation_fn.activation
@@ -468,6 +470,11 @@ def grad(xbar: torch.Tensor, y: torch.Tensor, functype: int) -> torch.Tensor:
     elif functype == 6:
         return y * (1 - y)

+    elif functype == 7:
+        # silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+        sig = torch.sigmoid(xbar)
+        return sig + xbar * sig * (1 - sig)
+
     else:
         raise ValueError(f"Unsupported function type: {functype}")
@@ -495,6 +502,12 @@ def grad_grad(xbar: torch.Tensor, y: torch.Tensor, functype: int) -> torch.Tensor:
     elif functype == 6:
         return y * (1 - y) * (1 - 2 * y)

+    elif functype == 7:
+        sig = torch.sigmoid(xbar)
+        d_sig = sig * (1 - sig)
+        # silu''(x) = 2 * d_sig + x * d_sig * (1 - 2 * sig)
+        return 2 * d_sig + xbar * d_sig * (1 - 2 * sig)
+
     else:
         return -torch.ones_like(xbar)
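For reference, the closed forms added above follow from silu(x) = x * sigmoid(x): the first derivative is sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)), and the second derivative is 2 * sigmoid'(x) + x * sigmoid''(x) with sigmoid''(x) = sigmoid'(x) * (1 - 2 * sigmoid(x)). A minimal sketch (not part of the commit, assuming only PyTorch) that checks both expressions against autograd:

```python
# Sanity sketch: verify the functype == 7 formulas above against autograd.
# x plays the role of xbar in grad()/grad_grad().
import torch

x = torch.linspace(-4.0, 4.0, 101, dtype=torch.float64, requires_grad=True)
y = torch.nn.functional.silu(x)  # silu(x) = x * sigmoid(x)

(dy,) = torch.autograd.grad(y.sum(), x, create_graph=True)
(d2y,) = torch.autograd.grad(dy.sum(), x)

sig = torch.sigmoid(x)
d_sig = sig * (1 - sig)
dy_closed = sig + x * d_sig                         # grad(): silu'(x)
d2y_closed = 2 * d_sig + x * d_sig * (1 - 2 * sig)  # grad_grad(): silu''(x)

torch.testing.assert_close(dy.detach(), dy_closed.detach())
torch.testing.assert_close(d2y.detach(), d2y_closed.detach())
```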
source/op/tf/unaggregated_grad.cc
9 additions, 0 deletions

@@ -75,6 +75,10 @@ FPTYPE grad(const FPTYPE xbar,
     case 6: {
       return y * (1 - y);
     }
+    case 7: {
+      const FPTYPE sig = 1.0 / (1.0 + exp(-xbar));
+      return sig + xbar * sig * (1 - sig);
+    }
     default:
       return -1;
   }
@@ -105,6 +109,11 @@ FPTYPE grad_grad(const FPTYPE xbar, const FPTYPE y, const int functype) {
     case 6: {
       return y * (1 - y) * (1 - 2 * y);
     }
+    case 7: {
+      const FPTYPE sig = 1.0 / (1.0 + exp(-xbar));
+      const FPTYPE d_sig = sig * (1 - sig);
+      return 2 * d_sig + xbar * d_sig * (1 - 2 * sig);
+    }
     default:
       return -1;
   }
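The C++ kernel expresses the same math with an explicit sigmoid rather than torch.sigmoid. A small consistency sketch (illustrative only, not from the commit) porting the case 7 branches to NumPy and cross-checking them against the PyTorch expressions used in deepmd/pt/utils/tabulate.py:

```python
# NumPy port of the case 7 branches above, cross-checked against the
# torch-based formulas from the PyTorch tabulation path.
import numpy as np
import torch

xbar = np.linspace(-4.0, 4.0, 101)

sig = 1.0 / (1.0 + np.exp(-xbar))                         # as in the C++ kernel
d_sig = sig * (1 - sig)
grad_np = sig + xbar * sig * (1 - sig)                    # grad(), case 7
grad_grad_np = 2 * d_sig + xbar * d_sig * (1 - 2 * sig)   # grad_grad(), case 7

t = torch.from_numpy(xbar)
s = torch.sigmoid(t)
ds = s * (1 - s)
np.testing.assert_allclose(grad_np, (s + t * ds).numpy(), atol=1e-12)
np.testing.assert_allclose(grad_grad_np, (2 * ds + t * ds * (1 - 2 * s)).numpy(), atol=1e-12)
```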

source/tests/pt/test_tabulate.py
80 additions, 22 deletions

@@ -4,6 +4,9 @@
 import numpy as np
 import torch

+from deepmd.dpmodel.utils.network import (
+    get_activation_fn,
+)
 from deepmd.pt.utils import (
     env,
 )
@@ -18,6 +21,24 @@
     tf,
 )

+ACTIVATION_NAMES = {
+    1: "tanh",
+    2: "gelu",
+    3: "relu",
+    4: "relu6",
+    5: "softplus",
+    6: "sigmoid",
+    7: "silu",
+}
+
+
+def get_activation_function(functype: int):
+    """Get activation function corresponding to functype."""
+    if functype not in ACTIVATION_NAMES:
+        raise ValueError(f"Unknown functype: {functype}")
+
+    return get_activation_fn(ACTIVATION_NAMES[functype])
+

 def setUpModule() -> None:
     tf.compat.v1.enable_eager_execution()
@@ -43,92 +64,129 @@ def setUp(self) -> None:

         self.xbar = np.matmul(self.x, self.w) + self.b  # 4 x 4

-        self.y = np.tanh(self.xbar)
-
     def test_ops(self) -> None:
+        """Test all activation functions using parameterized subtests."""
+        for functype in ACTIVATION_NAMES.keys():
+            activation_name = ACTIVATION_NAMES[functype]
+            activation_fn = get_activation_function(functype)
+
+            with self.subTest(activation=activation_name, functype=functype):
+                self._test_single_activation(functype, activation_fn, activation_name)
+
+    def _test_single_activation(
+        self, functype: int, activation_fn, activation_name: str
+    ) -> None:
+        """Test tabulation operations for a specific activation function."""
+        # Compute y using the specific activation function
+        y = activation_fn(self.xbar)
+
+        # Test unaggregated_dy_dx_s
         dy_tf = op_module.unaggregated_dy_dx_s(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             tf.constant(self.w, dtype="double"),
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )

         dy_pt = unaggregated_dy_dx_s(
-            torch.from_numpy(self.y),
+            torch.from_numpy(y),
             self.w,
             torch.from_numpy(self.xbar),
-            1,
+            functype,
         )

         dy_tf_numpy = dy_tf.numpy()
         dy_pt_numpy = dy_pt.detach().cpu().numpy()

-        np.testing.assert_almost_equal(dy_tf_numpy, dy_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dy_tf_numpy,
+            dy_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy_dx_s failed for {activation_name}",
+        )

+        # Test unaggregated_dy2_dx_s
         dy2_tf = op_module.unaggregated_dy2_dx_s(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             dy_tf,
             tf.constant(self.w, dtype="double"),
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )

         dy2_pt = unaggregated_dy2_dx_s(
-            torch.from_numpy(self.y),
+            torch.from_numpy(y),
             dy_pt,
             self.w,
             torch.from_numpy(self.xbar),
-            1,
+            functype,
         )

         dy2_tf_numpy = dy2_tf.numpy()
         dy2_pt_numpy = dy2_pt.detach().cpu().numpy()

-        np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dy2_tf_numpy,
+            dy2_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy2_dx_s failed for {activation_name}",
+        )

+        # Test unaggregated_dy_dx
         dz_tf = op_module.unaggregated_dy_dx(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             tf.constant(self.w, dtype="double"),
             dy_tf,
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )

         dz_pt = unaggregated_dy_dx(
-            torch.from_numpy(self.y).to(env.DEVICE),
+            torch.from_numpy(y).to(env.DEVICE),
             self.w,
             dy_pt,
             torch.from_numpy(self.xbar).to(env.DEVICE),
-            1,
+            functype,
         )

         dz_tf_numpy = dz_tf.numpy()
         dz_pt_numpy = dz_pt.detach().cpu().numpy()

-        np.testing.assert_almost_equal(dz_tf_numpy, dz_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dz_tf_numpy,
+            dz_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy_dx failed for {activation_name}",
+        )

+        # Test unaggregated_dy2_dx
         dy2_tf = op_module.unaggregated_dy2_dx(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             tf.constant(self.w, dtype="double"),
             dy_tf,
             dy2_tf,
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
        )

         dy2_pt = unaggregated_dy2_dx(
-            torch.from_numpy(self.y).to(env.DEVICE),
+            torch.from_numpy(y).to(env.DEVICE),
             self.w,
             dy_pt,
             dy2_pt,
             torch.from_numpy(self.xbar).to(env.DEVICE),
-            1,
+            functype,
         )

         dy2_tf_numpy = dy2_tf.numpy()
         dy2_pt_numpy = dy2_pt.detach().cpu().numpy()

-        np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dy2_tf_numpy,
+            dy2_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy2_dx failed for {activation_name}",
+        )


 if __name__ == "__main__":
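The updated test iterates over every functype and delegates to a helper inside self.subTest(...), so a failure for one activation is reported without aborting the remaining activations. A stripped-down, generic sketch of that pattern (not tied to this repository; the activation names are taken from the diff):

```python
# Minimal illustration of the subTest-based parameterization used above:
# each activation gets its own reported result, and one failure does not
# stop the loop over the others.
import unittest


class DemoSubTest(unittest.TestCase):
    def test_all_activations(self) -> None:
        for name in ("tanh", "gelu", "relu", "relu6", "softplus", "sigmoid", "silu"):
            with self.subTest(activation=name):
                self.assertIsInstance(name, str)  # placeholder check


if __name__ == "__main__":
    unittest.main()
```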
