diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py
index 727d01921f..a308f2d36b 100644
--- a/deepmd/pt/utils/tabulate.py
+++ b/deepmd/pt/utils/tabulate.py
@@ -46,7 +46,8 @@ class DPTabulate(BaseTabulate):
         The excluded pairs of types which have no interaction with each other.
         For example, `[[0, 1]]` means no interaction between type 0 and type 1.
     activation_function
-        The activation function in the embedding net. Supported options are {"tanh","gelu"} in common.ActivationFn.
+        The activation function in the embedding net. See :class:`ActivationFn`
+        for supported options (e.g. "tanh", "gelu", "relu", "silu").
     """
 
     def __init__(
@@ -84,6 +85,7 @@ def __init__(
             "relu6": 4,
             "softplus": 5,
             "sigmoid": 6,
+            "silu": 7,
         }
         activation = activation_fn.activation
@@ -468,6 +470,11 @@ def grad(xbar: torch.Tensor, y: torch.Tensor, functype: int) -> torch.Tensor:
     elif functype == 6:
         return y * (1 - y)
 
+    elif functype == 7:
+        # silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+        sig = torch.sigmoid(xbar)
+        return sig + xbar * sig * (1 - sig)
+
     else:
         raise ValueError(f"Unsupported function type: {functype}")
@@ -495,6 +502,12 @@ def grad_grad(xbar: torch.Tensor, y: torch.Tensor, functype: int) -> torch.Tensor:
     elif functype == 6:
         return y * (1 - y) * (1 - 2 * y)
 
+    elif functype == 7:
+        sig = torch.sigmoid(xbar)
+        d_sig = sig * (1 - sig)
+        # silu''(x) = 2 * d_sig + x * d_sig * (1 - 2 * sig)
+        return 2 * d_sig + xbar * d_sig * (1 - 2 * sig)
+
     else:
         return -torch.ones_like(xbar)
diff --git a/source/op/tf/unaggregated_grad.cc b/source/op/tf/unaggregated_grad.cc
index cf645f6c21..329e25b2d2 100644
--- a/source/op/tf/unaggregated_grad.cc
+++ b/source/op/tf/unaggregated_grad.cc
@@ -75,6 +75,10 @@ FPTYPE grad(const FPTYPE xbar,
     case 6: {
       return y * (1 - y);
     }
+    case 7: {
+      const FPTYPE sig = 1.0 / (1.0 + exp(-xbar));
+      return sig + xbar * sig * (1 - sig);
+    }
     default:
       return -1;
   }
@@ -105,6 +109,11 @@ FPTYPE grad_grad(const FPTYPE xbar, const FPTYPE y, const int functype) {
     case 6: {
       return y * (1 - y) * (1 - 2 * y);
     }
+    case 7: {
+      const FPTYPE sig = 1.0 / (1.0 + exp(-xbar));
+      const FPTYPE d_sig = sig * (1 - sig);
+      return 2 * d_sig + xbar * d_sig * (1 - 2 * sig);
+    }
     default:
       return -1;
   }
diff --git a/source/tests/pt/test_tabulate.py b/source/tests/pt/test_tabulate.py
index 164819408f..21083f7bdd 100644
--- a/source/tests/pt/test_tabulate.py
+++ b/source/tests/pt/test_tabulate.py
@@ -4,6 +4,9 @@
 import numpy as np
 import torch
 
+from deepmd.dpmodel.utils.network import (
+    get_activation_fn,
+)
 from deepmd.pt.utils import (
     env,
 )
@@ -18,6 +21,24 @@
     tf,
 )
 
+ACTIVATION_NAMES = {
+    1: "tanh",
+    2: "gelu",
+    3: "relu",
+    4: "relu6",
+    5: "softplus",
+    6: "sigmoid",
+    7: "silu",
+}
+
+
+def get_activation_function(functype: int):
+    """Get activation function corresponding to functype."""
+    if functype not in ACTIVATION_NAMES:
+        raise ValueError(f"Unknown functype: {functype}")
+
+    return get_activation_fn(ACTIVATION_NAMES[functype])
+
 
 def setUpModule() -> None:
     tf.compat.v1.enable_eager_execution()
@@ -43,92 +64,129 @@ def setUp(self) -> None:
 
         self.xbar = np.matmul(self.x, self.w) + self.b  # 4 x 4
 
-        self.y = np.tanh(self.xbar)
-
     def test_ops(self) -> None:
+        """Test all activation functions using parameterized subtests."""
+        for functype in ACTIVATION_NAMES.keys():
+            activation_name = ACTIVATION_NAMES[functype]
+            activation_fn = get_activation_function(functype)
+
+            with self.subTest(activation=activation_name, functype=functype):
+                self._test_single_activation(functype, activation_fn, activation_name)
+
+    def _test_single_activation(
+        self, functype: int, activation_fn, activation_name: str
+    ) -> None:
+        """Test tabulation operations for a specific activation function."""
+        # Compute y using the specific activation function
+        y = activation_fn(self.xbar)
+
+        # Test unaggregated_dy_dx_s
         dy_tf = op_module.unaggregated_dy_dx_s(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             tf.constant(self.w, dtype="double"),
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )
         dy_pt = unaggregated_dy_dx_s(
-            torch.from_numpy(self.y),
+            torch.from_numpy(y),
             self.w,
             torch.from_numpy(self.xbar),
-            1,
+            functype,
         )
 
         dy_tf_numpy = dy_tf.numpy()
         dy_pt_numpy = dy_pt.detach().cpu().numpy()
 
-        np.testing.assert_almost_equal(dy_tf_numpy, dy_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dy_tf_numpy,
+            dy_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy_dx_s failed for {activation_name}",
+        )
 
+        # Test unaggregated_dy2_dx_s
         dy2_tf = op_module.unaggregated_dy2_dx_s(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             dy_tf,
             tf.constant(self.w, dtype="double"),
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )
         dy2_pt = unaggregated_dy2_dx_s(
-            torch.from_numpy(self.y),
+            torch.from_numpy(y),
             dy_pt,
             self.w,
             torch.from_numpy(self.xbar),
-            1,
+            functype,
         )
 
         dy2_tf_numpy = dy2_tf.numpy()
         dy2_pt_numpy = dy2_pt.detach().cpu().numpy()
 
-        np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dy2_tf_numpy,
+            dy2_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy2_dx_s failed for {activation_name}",
+        )
 
+        # Test unaggregated_dy_dx
         dz_tf = op_module.unaggregated_dy_dx(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             tf.constant(self.w, dtype="double"),
             dy_tf,
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )
         dz_pt = unaggregated_dy_dx(
-            torch.from_numpy(self.y).to(env.DEVICE),
+            torch.from_numpy(y).to(env.DEVICE),
             self.w,
             dy_pt,
             torch.from_numpy(self.xbar).to(env.DEVICE),
-            1,
+            functype,
        )
 
         dz_tf_numpy = dz_tf.numpy()
         dz_pt_numpy = dz_pt.detach().cpu().numpy()
 
-        np.testing.assert_almost_equal(dz_tf_numpy, dz_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dz_tf_numpy,
+            dz_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy_dx failed for {activation_name}",
+        )
 
+        # Test unaggregated_dy2_dx
         dy2_tf = op_module.unaggregated_dy2_dx(
-            tf.constant(self.y, dtype="double"),
+            tf.constant(y, dtype="double"),
             tf.constant(self.w, dtype="double"),
             dy_tf,
             dy2_tf,
             tf.constant(self.xbar, dtype="double"),
-            tf.constant(1),
+            tf.constant(functype),
         )
         dy2_pt = unaggregated_dy2_dx(
-            torch.from_numpy(self.y).to(env.DEVICE),
+            torch.from_numpy(y).to(env.DEVICE),
             self.w,
             dy_pt,
             dy2_pt,
             torch.from_numpy(self.xbar).to(env.DEVICE),
-            1,
+            functype,
         )
 
         dy2_tf_numpy = dy2_tf.numpy()
         dy2_pt_numpy = dy2_pt.detach().cpu().numpy()
 
-        np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10)
+        np.testing.assert_almost_equal(
+            dy2_tf_numpy,
+            dy2_pt_numpy,
+            decimal=10,
+            err_msg=f"unaggregated_dy2_dx failed for {activation_name}",
+        )
 
 
 if __name__ == "__main__":
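
Note (reviewer sketch, not part of the patch): the closed-form SiLU derivatives added in `grad()` and `grad_grad()` above can be cross-checked numerically against PyTorch autograd. A minimal standalone script follows; the helper name `check_silu_grads` is hypothetical, everything else uses standard `torch` APIs.

```python
import torch


def check_silu_grads() -> None:
    """Compare the patch's closed-form SiLU derivatives with autograd."""
    x = torch.linspace(-5.0, 5.0, 101, dtype=torch.float64, requires_grad=True)
    y = torch.nn.functional.silu(x)  # silu(x) = x * sigmoid(x)

    # First derivative via autograd; create_graph=True lets us differentiate again.
    (dy,) = torch.autograd.grad(y.sum(), x, create_graph=True)
    (d2y,) = torch.autograd.grad(dy.sum(), x)

    xd = x.detach()
    sig = torch.sigmoid(xd)
    d_sig = sig * (1 - sig)
    # Closed forms used for functype == 7 in the patch:
    #   silu'(x)  = sig + x * sig * (1 - sig)
    #   silu''(x) = 2 * d_sig + x * d_sig * (1 - 2 * sig)
    dy_closed = sig + xd * d_sig
    d2y_closed = 2 * d_sig + xd * d_sig * (1 - 2 * sig)

    torch.testing.assert_close(dy.detach(), dy_closed)
    torch.testing.assert_close(d2y.detach(), d2y_closed)


if __name__ == "__main__":
    check_silu_grads()
    print("SiLU closed-form derivatives match autograd.")
```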