Skip to content

Commit d5590a4

Browse files
wanghan-iapcmHan Wang
andauthored
fix: model check assumes __call__ as the forward method (#3136)
- add `__call__` method for `NativeOP`. - adapt UTs accordingly. Co-authored-by: Han Wang <[email protected]>
1 parent 828df66 commit d5590a4

File tree

4 files changed

+38
-28
lines changed

4 files changed

+38
-28
lines changed

deepmd_utils/model_format/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .common import (
33
DEFAULT_PRECISION,
44
PRECISION_DICT,
5+
NativeOP,
56
)
67
from .env_mat import (
78
EnvMat,
@@ -34,6 +35,7 @@
3435
"NativeLayer",
3536
"NativeNet",
3637
"NetworkCollection",
38+
"NativeOP",
3739
"load_dp_model",
3840
"save_dp_model",
3941
"traverse_model_dict",

deepmd_utils/model_format/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,7 @@ class NativeOP(ABC):
2222
def call(self, *args, **kwargs):
2323
"""Forward pass in NumPy implementation."""
2424
raise NotImplementedError
25+
26+
def __call__(self, *args, **kwargs):
27+
"""Forward pass in NumPy implementation."""
28+
return self.call(*args, **kwargs)

deepmd_utils/model_format/output_def.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def model_check_output(cls):
2727
2828
Two methods are assumed to be provided by the Model:
2929
1. Model.output_def that gives the output definition.
30-
2. Model.forward that defines the forward path of the model.
30+
2. Model.__call__ that defines the forward path of the model.
3131
3232
"""
3333

@@ -40,12 +40,12 @@ def __init__(
4040
super().__init__(*args, **kwargs)
4141
self.md = cls.output_def(self)
4242

43-
def forward(
43+
def __call__(
4444
self,
4545
*args,
4646
**kwargs,
4747
):
48-
ret = cls.forward(self, *args, **kwargs)
48+
ret = cls.__call__(self, *args, **kwargs)
4949
for kk in self.md.keys_outp():
5050
dd = self.md[kk]
5151
check_var(ret[kk], dd)
@@ -66,7 +66,7 @@ def fitting_check_output(cls):
6666
6767
Two methods are assumed to be provided by the Fitting:
6868
1. Fitting.output_def that gives the output definition.
69-
2. Fitting.forward defines the forward path of the fitting.
69+
2. Fitting.__call__ defines the forward path of the fitting.
7070
7171
"""
7272

@@ -79,12 +79,12 @@ def __init__(
7979
super().__init__(*args, **kwargs)
8080
self.md = cls.output_def(self)
8181

82-
def forward(
82+
def __call__(
8383
self,
8484
*args,
8585
**kwargs,
8686
):
87-
ret = cls.forward(self, *args, **kwargs)
87+
ret = cls.__call__(self, *args, **kwargs)
8888
for kk in self.md.keys():
8989
dd = self.md[kk]
9090
check_var(ret[kk], dd)

source/tests/test_output_def.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from deepmd_utils.model_format import (
77
FittingOutputDef,
88
ModelOutputDef,
9+
NativeOP,
910
OutputVariableDef,
1011
fitting_check_output,
1112
model_check_output,
@@ -91,14 +92,14 @@ def test_model_decorator(self):
9192
nloc = 3
9293

9394
@model_check_output
94-
class Foo:
95+
class Foo(NativeOP):
9596
def output_def(self):
9697
defs = [
9798
OutputVariableDef("energy", [1], True, True),
9899
]
99100
return ModelOutputDef(FittingOutputDef(defs))
100101

101-
def forward(self):
102+
def call(self):
102103
return {
103104
"energy": np.zeros([nf, nloc, 1]),
104105
"energy_redu": np.zeros([nf, 1]),
@@ -107,21 +108,24 @@ def forward(self):
107108
}
108109

109110
ff = Foo()
110-
ff.forward()
111+
ff()
111112

112113
def test_model_decorator_keyerror(self):
113114
nf = 2
114115
nloc = 3
115116

116117
@model_check_output
117-
class Foo:
118+
class Foo(NativeOP):
119+
def __init__(self):
120+
super().__init__()
121+
118122
def output_def(self):
119123
defs = [
120124
OutputVariableDef("energy", [1], True, True),
121125
]
122126
return ModelOutputDef(FittingOutputDef(defs))
123127

124-
def forward(self):
128+
def call(self):
125129
return {
126130
"energy": np.zeros([nf, nloc, 1]),
127131
"energy_redu": np.zeros([nf, 1]),
@@ -130,15 +134,15 @@ def forward(self):
130134

131135
ff = Foo()
132136
with self.assertRaises(KeyError) as context:
133-
ff.forward()
137+
ff()
134138
self.assertIn("energy_derv_r", context.exception)
135139

136140
def test_model_decorator_shapeerror(self):
137141
nf = 2
138142
nloc = 3
139143

140144
@model_check_output
141-
class Foo:
145+
class Foo(NativeOP):
142146
def __init__(
143147
self,
144148
shape_rd=[nf, 1],
@@ -152,7 +156,7 @@ def output_def(self):
152156
]
153157
return ModelOutputDef(FittingOutputDef(defs))
154158

155-
def forward(self):
159+
def call(self):
156160
return {
157161
"energy": np.zeros([nf, nloc, 1]),
158162
"energy_redu": np.zeros(self.shape_rd),
@@ -161,56 +165,56 @@ def forward(self):
161165
}
162166

163167
ff = Foo()
164-
ff.forward()
168+
ff()
165169
# shape of reduced energy
166170
with self.assertRaises(ValueError) as context:
167171
ff = Foo(shape_rd=[nf, nloc, 1])
168-
ff.forward()
172+
ff()
169173
self.assertIn("not matching", context.exception)
170174
with self.assertRaises(ValueError) as context:
171175
ff = Foo(shape_rd=[nf, 2])
172-
ff.forward()
176+
ff()
173177
self.assertIn("not matching", context.exception)
174178
# shape of dr
175179
with self.assertRaises(ValueError) as context:
176180
ff = Foo(shape_dr=[nf, nloc, 1])
177-
ff.forward()
181+
ff()
178182
self.assertIn("not matching", context.exception)
179183
with self.assertRaises(ValueError) as context:
180184
ff = Foo(shape_dr=[nf, nloc, 1, 3, 3])
181-
ff.forward()
185+
ff()
182186
self.assertIn("not matching", context.exception)
183187
with self.assertRaises(ValueError) as context:
184188
ff = Foo(shape_dr=[nf, nloc, 1, 4])
185-
ff.forward()
189+
ff()
186190
self.assertIn("not matching", context.exception)
187191

188192
def test_fitting_decorator(self):
189193
nf = 2
190194
nloc = 3
191195

192196
@fitting_check_output
193-
class Foo:
197+
class Foo(NativeOP):
194198
def output_def(self):
195199
defs = [
196200
OutputVariableDef("energy", [1], True, True),
197201
]
198202
return FittingOutputDef(defs)
199203

200-
def forward(self):
204+
def call(self):
201205
return {
202206
"energy": np.zeros([nf, nloc, 1]),
203207
}
204208

205209
ff = Foo()
206-
ff.forward()
210+
ff()
207211

208212
def test_fitting_decorator_shapeerror(self):
209213
nf = 2
210214
nloc = 3
211215

212216
@fitting_check_output
213-
class Foo:
217+
class Foo(NativeOP):
214218
def __init__(
215219
self,
216220
shape=[nf, nloc, 1],
@@ -223,19 +227,19 @@ def output_def(self):
223227
]
224228
return FittingOutputDef(defs)
225229

226-
def forward(self):
230+
def call(self):
227231
return {
228232
"energy": np.zeros(self.shape),
229233
}
230234

231235
ff = Foo()
232-
ff.forward()
236+
ff()
233237
# shape of reduced energy
234238
with self.assertRaises(ValueError) as context:
235239
ff = Foo(shape=[nf, 1])
236-
ff.forward()
240+
ff()
237241
self.assertIn("not matching", context.exception)
238242
with self.assertRaises(ValueError) as context:
239243
ff = Foo(shape=[nf, nloc, 2])
240-
ff.forward()
244+
ff()
241245
self.assertIn("not matching", context.exception)

0 commit comments

Comments
 (0)