deepmd/pt/utils/tabulate.py (14 additions, 1 deletion)
@@ -46,7 +46,8 @@ class DPTabulate(BaseTabulate):
         The excluded pairs of types which have no interaction with each other.
         For example, `[[0, 1]]` means no interaction between type 0 and type 1.
     activation_function
-        The activation function in the embedding net. Supported options are {"tanh","gelu"} in common.ActivationFn.
+        The activation function in the embedding net. See :class:`ActivationFn`
+        for supported options (e.g. "tanh", "gelu", "relu", "silu").
     """

     def __init__(
@@ -84,6 +85,7 @@ def __init__(
             "relu6": 4,
             "softplus": 5,
             "sigmoid": 6,
+            "silu": 7,
         }

         activation = activation_fn.activation
@@ -468,6 +470,11 @@ def grad(xbar: torch.Tensor, y: torch.Tensor, functype: int) -> torch.Tensor:
     elif functype == 6:
         return y * (1 - y)

+    elif functype == 7:
+        # silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+        sig = torch.sigmoid(xbar)
+        return sig + xbar * sig * (1 - sig)
+
     else:
         raise ValueError(f"Unsupported function type: {functype}")

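A quick numerical sanity check of this first-derivative closed form is straightforward; the sketch below is illustrative only (not part of the PR) and assumes nothing beyond a working torch install:

```python
# Verify silu'(x) = sig * (1 + x * (1 - sig)) against PyTorch autograd in fp64.
import torch

x = torch.linspace(-5.0, 5.0, 101, dtype=torch.float64, requires_grad=True)
(dy_auto,) = torch.autograd.grad(torch.nn.functional.silu(x).sum(), x)

with torch.no_grad():
    sig = torch.sigmoid(x)
    dy_closed = sig + x * sig * (1 - sig)  # same closed form as grad() above

torch.testing.assert_close(dy_auto, dy_closed)
```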
@@ -495,6 +502,12 @@ def grad_grad(xbar: torch.Tensor, y: torch.Tensor, functype: int) -> torch.Tensor:
     elif functype == 6:
         return y * (1 - y) * (1 - 2 * y)

+    elif functype == 7:
+        sig = torch.sigmoid(xbar)
+        d_sig = sig * (1 - sig)
+        # silu''(x) = 2 * d_sig + x * d_sig * (1 - 2 * sig)
+        return 2 * d_sig + xbar * d_sig * (1 - 2 * sig)
+
     else:
         return -torch.ones_like(xbar)

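Both closed forms follow from $\operatorname{silu}(x) = x\,\sigma(x)$ by the product and chain rules; for reference (standard calculus, with $\sigma$ the logistic sigmoid):

$$
\begin{aligned}
\sigma'(x) &= \sigma(x)\bigl(1-\sigma(x)\bigr), \qquad \sigma''(x) = \sigma'(x)\bigl(1-2\sigma(x)\bigr),\\
\operatorname{silu}'(x) &= \sigma(x) + x\,\sigma'(x) = \sigma(x)\bigl(1 + x(1-\sigma(x))\bigr),\\
\operatorname{silu}''(x) &= 2\,\sigma'(x) + x\,\sigma''(x) = 2\,\sigma'(x) + x\,\sigma'(x)\bigl(1-2\sigma(x)\bigr),
\end{aligned}
$$

which is exactly the `2 * d_sig + xbar * d_sig * (1 - 2 * sig)` expression used above.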
source/op/tf/unaggregated_grad.cc (9 additions, 0 deletions)
@@ -75,6 +75,10 @@ FPTYPE grad(const FPTYPE xbar,
     case 6: {
       return y * (1 - y);
     }
+    case 7: {
+      const FPTYPE sig = 1.0 / (1.0 + exp(-xbar));
+      return sig + xbar * sig * (1 - sig);
+    }
     default:
       return -1;
   }
@@ -105,6 +109,11 @@ FPTYPE grad_grad(const FPTYPE xbar, const FPTYPE y, const int functype) {
     case 6: {
       return y * (1 - y) * (1 - 2 * y);
     }
+    case 7: {
+      const FPTYPE sig = 1.0 / (1.0 + exp(-xbar));
+      const FPTYPE d_sig = sig * (1 - sig);
+      return 2 * d_sig + xbar * d_sig * (1 - 2 * sig);
+    }
     default:
       return -1;
   }
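The C++ cases mirror the PyTorch expressions line for line, so one symbolic check covers both backends; a minimal SymPy sketch (illustrative only, not part of the PR, assuming sympy is installed):

```python
# Symbolically confirm both closed forms against exact derivatives of
# silu(x) = x * sigmoid(x).
import sympy as sp

x = sp.symbols("x", real=True)
sig = 1 / (1 + sp.exp(-x))
d_sig = sig * (1 - sig)

first = sig + x * sig * (1 - sig)               # grad() / grad(), case 7
second = 2 * d_sig + x * d_sig * (1 - 2 * sig)  # grad_grad(), case 7

assert sp.diff(x * sig, x).equals(first)        # silu'
assert sp.diff(x * sig, x, 2).equals(second)    # silu''
```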
source/tests/pt/test_tabulate.py (94 additions, 22 deletions)
@@ -19,6 +19,41 @@
 )


+def get_activation_function(functype: int):
+    """Get activation function corresponding to functype."""
+    if functype == 1:
+        return lambda x: np.tanh(x)
+    elif functype == 2:
+        return (
+            lambda x: 0.5
+            * x
+            * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
+        )
+    elif functype == 3:
+        return lambda x: np.maximum(x, 0)
+    elif functype == 4:
+        return lambda x: np.minimum(np.maximum(x, 0), 6)
+    elif functype == 5:
+        return lambda x: np.log(1 + np.exp(x))
+    elif functype == 6:
+        return lambda x: 1 / (1 + np.exp(-x))
+    elif functype == 7:
+        return lambda x: x / (1 + np.exp(-x))
+    else:
+        raise ValueError(f"Unknown functype: {functype}")
+
+
+ACTIVATION_NAMES = {
+    1: "tanh",
+    2: "gelu",
+    3: "relu",
+    4: "relu6",
+    5: "softplus",
+    6: "sigmoid",
+    7: "silu",
+}
+
+
 def setUpModule() -> None:
     tf.compat.v1.enable_eager_execution()

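The NumPy functions above are reference implementations for the test; note that functype 2 is the tanh approximation of GELU, which is what the tabulation ops use. As an illustrative cross-check (not part of the PR, assuming a torch build that supports `approximate="tanh"`), they can be compared against their torch counterparts:

```python
# Each NumPy reference activation should agree with torch to fp64 precision.
import numpy as np
import torch
import torch.nn.functional as F

torch_fns = {
    1: torch.tanh,
    2: lambda t: F.gelu(t, approximate="tanh"),  # tanh-approximate GELU, as above
    3: F.relu,
    4: F.relu6,
    5: F.softplus,
    6: torch.sigmoid,
    7: F.silu,
}

x = np.linspace(-4.0, 4.0, 81)
t = torch.from_numpy(x)  # float64 tensor
for functype, fn in torch_fns.items():
    np.testing.assert_allclose(
        get_activation_function(functype)(x),
        fn(t).numpy(),
        atol=1e-12,
    )
```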
@@ -43,92 +78,129 @@

         self.xbar = np.matmul(self.x, self.w) + self.b  # 4 x 4

-        self.y = np.tanh(self.xbar)
-
     def test_ops(self) -> None:
+        """Test all activation functions using parameterized subtests."""
+        for functype in ACTIVATION_NAMES.keys():
+            activation_name = ACTIVATION_NAMES[functype]
+            activation_fn = get_activation_function(functype)
+
+            with self.subTest(activation=activation_name, functype=functype):
+                self._test_single_activation(functype, activation_fn, activation_name)
+
+    def _test_single_activation(
+        self, functype: int, activation_fn, activation_name: str
+    ) -> None:
+        """Test tabulation operations for a specific activation function."""
+        # Compute y using the specific activation function
+        y = activation_fn(self.xbar)
+
+        # Test unaggregated_dy_dx_s
         dy_tf = op_module.unaggregated_dy_dx_s(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             tf.constant(self.w, dtype="double"),
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )

         dy_pt = unaggregated_dy_dx_s(
-            torch.from_numpy(self.y),
+            torch.from_numpy(y),
             self.w,
             torch.from_numpy(self.xbar),
-            1,
+            functype,
         )

         dy_tf_numpy = dy_tf.numpy()
         dy_pt_numpy = dy_pt.detach().cpu().numpy()

-        np.testing.assert_almost_equal(dy_tf_numpy, dy_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dy_tf_numpy,
+            dy_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy_dx_s failed for {activation_name}",
+        )

+        # Test unaggregated_dy2_dx_s
         dy2_tf = op_module.unaggregated_dy2_dx_s(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             dy_tf,
             tf.constant(self.w, dtype="double"),
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )

         dy2_pt = unaggregated_dy2_dx_s(
-            torch.from_numpy(self.y),
+            torch.from_numpy(y),
             dy_pt,
             self.w,
             torch.from_numpy(self.xbar),
-            1,
+            functype,
         )

         dy2_tf_numpy = dy2_tf.numpy()
         dy2_pt_numpy = dy2_pt.detach().cpu().numpy()

-        np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dy2_tf_numpy,
+            dy2_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy2_dx_s failed for {activation_name}",
+        )

+        # Test unaggregated_dy_dx
         dz_tf = op_module.unaggregated_dy_dx(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             tf.constant(self.w, dtype="double"),
             dy_tf,
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )

         dz_pt = unaggregated_dy_dx(
-            torch.from_numpy(self.y).to(env.DEVICE),
+            torch.from_numpy(y).to(env.DEVICE),
             self.w,
             dy_pt,
             torch.from_numpy(self.xbar).to(env.DEVICE),
-            1,
+            functype,
         )

         dz_tf_numpy = dz_tf.numpy()
         dz_pt_numpy = dz_pt.detach().cpu().numpy()

-        np.testing.assert_almost_equal(dz_tf_numpy, dz_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dz_tf_numpy,
+            dz_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy_dx failed for {activation_name}",
+        )

+        # Test unaggregated_dy2_dx
         dy2_tf = op_module.unaggregated_dy2_dx(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             tf.constant(self.w, dtype="double"),
             dy_tf,
             dy2_tf,
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )

         dy2_pt = unaggregated_dy2_dx(
-            torch.from_numpy(self.y).to(env.DEVICE),
+            torch.from_numpy(y).to(env.DEVICE),
             self.w,
             dy_pt,
             dy2_pt,
             torch.from_numpy(self.xbar).to(env.DEVICE),
-            1,
+            functype,
         )

         dy2_tf_numpy = dy2_tf.numpy()
         dy2_pt_numpy = dy2_pt.detach().cpu().numpy()

-        np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dy2_tf_numpy,
+            dy2_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy2_dx failed for {activation_name}",
+        )


 if __name__ == "__main__":
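Assuming the custom TensorFlow ops have been built, the extended suite can be run on its own, e.g. with `python -m pytest source/tests/pt/test_tabulate.py -v`; each activation then surfaces as a separate subtest, so a SiLU-only regression is reported under its own name.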