44import numpy as np
55import torch
66
7+ from deepmd .dpmodel .utils .network import (
8+ get_activation_fn ,
9+ )
710from deepmd .pt .utils import (
811 env ,
912)
1821 tf ,
1922)
2023
24+ ACTIVATION_NAMES = {
25+ 1 : "tanh" ,
26+ 2 : "gelu" ,
27+ 3 : "relu" ,
28+ 4 : "relu6" ,
29+ 5 : "softplus" ,
30+ 6 : "sigmoid" ,
31+ 7 : "silu" ,
32+ }
33+
34+
35+ def get_activation_function (functype : int ):
36+ """Get activation function corresponding to functype."""
37+ if functype not in ACTIVATION_NAMES :
38+ raise ValueError (f"Unknown functype: { functype } " )
39+
40+ return get_activation_fn (ACTIVATION_NAMES [functype ])
41+
2142
2243def setUpModule () -> None :
2344 tf .compat .v1 .enable_eager_execution ()
@@ -43,92 +64,129 @@ def setUp(self) -> None:
4364
4465 self .xbar = np .matmul (self .x , self .w ) + self .b # 4 x 4
4566
46- self .y = np .tanh (self .xbar )
47-
4867 def test_ops (self ) -> None :
68+ """Test all activation functions using parameterized subtests."""
69+ for functype in ACTIVATION_NAMES .keys ():
70+ activation_name = ACTIVATION_NAMES [functype ]
71+ activation_fn = get_activation_function (functype )
72+
73+ with self .subTest (activation = activation_name , functype = functype ):
74+ self ._test_single_activation (functype , activation_fn , activation_name )
75+
76+ def _test_single_activation (
77+ self , functype : int , activation_fn , activation_name : str
78+ ) -> None :
79+ """Test tabulation operations for a specific activation function."""
80+ # Compute y using the specific activation function
81+ y = activation_fn (self .xbar )
82+
83+ # Test unaggregated_dy_dx_s
4984 dy_tf = op_module .unaggregated_dy_dx_s (
50- tf .constant (self . y , dtype = "double" ),
85+ tf .constant (y , dtype = "double" ),
5186 tf .constant (self .w , dtype = "double" ),
5287 tf .constant (self .xbar , dtype = "double" ),
53- tf .constant (1 ),
88+ tf .constant (functype ),
5489 )
5590
5691 dy_pt = unaggregated_dy_dx_s (
57- torch .from_numpy (self . y ),
92+ torch .from_numpy (y ),
5893 self .w ,
5994 torch .from_numpy (self .xbar ),
60- 1 ,
95+ functype ,
6196 )
6297
6398 dy_tf_numpy = dy_tf .numpy ()
6499 dy_pt_numpy = dy_pt .detach ().cpu ().numpy ()
65100
66- np .testing .assert_almost_equal (dy_tf_numpy , dy_pt_numpy , decimal = 10 )
101+ np .testing .assert_almost_equal (
102+ dy_tf_numpy ,
103+ dy_pt_numpy ,
104+ decimal = 10 ,
105+ err_msg = f"unaggregated_dy_dx_s failed for { activation_name } " ,
106+ )
67107
108+ # Test unaggregated_dy2_dx_s
68109 dy2_tf = op_module .unaggregated_dy2_dx_s (
69- tf .constant (self . y , dtype = "double" ),
110+ tf .constant (y , dtype = "double" ),
70111 dy_tf ,
71112 tf .constant (self .w , dtype = "double" ),
72113 tf .constant (self .xbar , dtype = "double" ),
73- tf .constant (1 ),
114+ tf .constant (functype ),
74115 )
75116
76117 dy2_pt = unaggregated_dy2_dx_s (
77- torch .from_numpy (self . y ),
118+ torch .from_numpy (y ),
78119 dy_pt ,
79120 self .w ,
80121 torch .from_numpy (self .xbar ),
81- 1 ,
122+ functype ,
82123 )
83124
84125 dy2_tf_numpy = dy2_tf .numpy ()
85126 dy2_pt_numpy = dy2_pt .detach ().cpu ().numpy ()
86127
87- np .testing .assert_almost_equal (dy2_tf_numpy , dy2_pt_numpy , decimal = 10 )
128+ np .testing .assert_almost_equal (
129+ dy2_tf_numpy ,
130+ dy2_pt_numpy ,
131+ decimal = 10 ,
132+ err_msg = f"unaggregated_dy2_dx_s failed for { activation_name } " ,
133+ )
88134
135+ # Test unaggregated_dy_dx
89136 dz_tf = op_module .unaggregated_dy_dx (
90- tf .constant (self . y , dtype = "double" ),
137+ tf .constant (y , dtype = "double" ),
91138 tf .constant (self .w , dtype = "double" ),
92139 dy_tf ,
93140 tf .constant (self .xbar , dtype = "double" ),
94- tf .constant (1 ),
141+ tf .constant (functype ),
95142 )
96143
97144 dz_pt = unaggregated_dy_dx (
98- torch .from_numpy (self . y ).to (env .DEVICE ),
145+ torch .from_numpy (y ).to (env .DEVICE ),
99146 self .w ,
100147 dy_pt ,
101148 torch .from_numpy (self .xbar ).to (env .DEVICE ),
102- 1 ,
149+ functype ,
103150 )
104151
105152 dz_tf_numpy = dz_tf .numpy ()
106153 dz_pt_numpy = dz_pt .detach ().cpu ().numpy ()
107154
108- np .testing .assert_almost_equal (dz_tf_numpy , dz_pt_numpy , decimal = 10 )
155+ np .testing .assert_almost_equal (
156+ dz_tf_numpy ,
157+ dz_pt_numpy ,
158+ decimal = 10 ,
159+ err_msg = f"unaggregated_dy_dx failed for { activation_name } " ,
160+ )
109161
162+ # Test unaggregated_dy2_dx
110163 dy2_tf = op_module .unaggregated_dy2_dx (
111- tf .constant (self . y , dtype = "double" ),
164+ tf .constant (y , dtype = "double" ),
112165 tf .constant (self .w , dtype = "double" ),
113166 dy_tf ,
114167 dy2_tf ,
115168 tf .constant (self .xbar , dtype = "double" ),
116- tf .constant (1 ),
169+ tf .constant (functype ),
117170 )
118171
119172 dy2_pt = unaggregated_dy2_dx (
120- torch .from_numpy (self . y ).to (env .DEVICE ),
173+ torch .from_numpy (y ).to (env .DEVICE ),
121174 self .w ,
122175 dy_pt ,
123176 dy2_pt ,
124177 torch .from_numpy (self .xbar ).to (env .DEVICE ),
125- 1 ,
178+ functype ,
126179 )
127180
128181 dy2_tf_numpy = dy2_tf .numpy ()
129182 dy2_pt_numpy = dy2_pt .detach ().cpu ().numpy ()
130183
131- np .testing .assert_almost_equal (dy2_tf_numpy , dy2_pt_numpy , decimal = 10 )
184+ np .testing .assert_almost_equal (
185+ dy2_tf_numpy ,
186+ dy2_pt_numpy ,
187+ decimal = 10 ,
188+ err_msg = f"unaggregated_dy2_dx failed for { activation_name } " ,
189+ )
132190
133191
134192if __name__ == "__main__" :
0 commit comments