Skip to content

Commit 5493921

Browse files
wanghan-iapcmHan Wangnjzjzpre-commit-ci[bot]
authored
fix: some issue of the output def (#3152)
- strict type hint - allow the last dim to be variable (by setting the dim to -1) - remove variable def, which is not very useful. - _derv_c should be defined for each atom --------- Co-authored-by: Han Wang <[email protected]> Co-authored-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1efc7f8 commit 5493921

File tree

3 files changed

+128
-64
lines changed

3 files changed

+128
-64
lines changed

deepmd_utils/model_format/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
FittingOutputDef,
2525
ModelOutputDef,
2626
OutputVariableDef,
27-
VariableDef,
2827
fitting_check_output,
28+
get_deriv_name,
29+
get_reduce_name,
2930
model_check_output,
3031
)
3132
from .se_e2_a import (
@@ -52,7 +53,8 @@
5253
"ModelOutputDef",
5354
"FittingOutputDef",
5455
"OutputVariableDef",
55-
"VariableDef",
5656
"model_check_output",
5757
"fitting_check_output",
58+
"get_reduce_name",
59+
"get_deriv_name",
5860
]

deepmd_utils/model_format/output_def.py

Lines changed: 59 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,34 @@
33
Dict,
44
List,
55
Tuple,
6-
Union,
76
)
87

98

9+
def check_shape(
10+
shape: List[int],
11+
def_shape: List[int],
12+
):
13+
"""Check if the shape satisfies the defined shape."""
14+
assert len(shape) == len(def_shape)
15+
if def_shape[-1] == -1:
16+
if list(shape[:-1]) != def_shape[:-1]:
17+
raise ValueError(f"{shape[:-1]} shape not matching def {def_shape[:-1]}")
18+
else:
19+
if list(shape) != def_shape:
20+
raise ValueError(f"{shape} shape not matching def {def_shape}")
21+
22+
1023
def check_var(var, var_def):
1124
if var_def.atomic:
1225
# var.shape == [nf, nloc, *var_def.shape]
1326
if len(var.shape) != len(var_def.shape) + 2:
1427
raise ValueError(f"{var.shape[2:]} length not matching def {var_def.shape}")
15-
if list(var.shape[2:]) != var_def.shape:
16-
raise ValueError(f"{var.shape[2:]} not matching def {var_def.shape}")
28+
check_shape(list(var.shape[2:]), var_def.shape)
1729
else:
1830
# var.shape == [nf, *var_def.shape]
1931
if len(var.shape) != len(var_def.shape) + 1:
2032
raise ValueError(f"{var.shape[1:]} length not matching def {var_def.shape}")
21-
if list(var.shape[1:]) != var_def.shape:
22-
raise ValueError(f"{var.shape[1:]} not matching def {var_def.shape}")
33+
check_shape(list(var.shape[1:]), var_def.shape)
2334

2435

2536
def model_check_output(cls):
@@ -38,7 +49,7 @@ def __init__(
3849
**kwargs,
3950
):
4051
super().__init__(*args, **kwargs)
41-
self.md = cls.output_def(self)
52+
self.md = self.output_def()
4253

4354
def __call__(
4455
self,
@@ -77,7 +88,7 @@ def __init__(
7788
**kwargs,
7889
):
7990
super().__init__(*args, **kwargs)
80-
self.md = cls.output_def(self)
91+
self.md = self.output_def()
8192

8293
def __call__(
8394
self,
@@ -93,35 +104,7 @@ def __call__(
93104
return wrapper
94105

95106

96-
class VariableDef:
97-
"""Defines the shape and other properties of a variable.
98-
99-
Parameters
100-
----------
101-
name
102-
Name of the output variable. Notice that the xxxx_redu,
103-
xxxx_derv_c, xxxx_derv_r are reserved names that should
104-
not be used to define variables.
105-
shape
106-
The shape of the variable. e.g. energy should be [1],
107-
dipole should be [3], polarizabilty should be [3,3].
108-
atomic
109-
If the variable is defined for each atom.
110-
111-
"""
112-
113-
def __init__(
114-
self,
115-
name: str,
116-
shape: Union[List[int], Tuple[int]],
117-
atomic: bool = True,
118-
):
119-
self.name = name
120-
self.shape = list(shape)
121-
self.atomic = atomic
122-
123-
124-
class OutputVariableDef(VariableDef):
107+
class OutputVariableDef:
125108
"""Defines the shape and other properties of the one output variable.
126109
127110
It is assume that the fitting network output variables for each
@@ -149,12 +132,14 @@ class OutputVariableDef(VariableDef):
149132
def __init__(
150133
self,
151134
name: str,
152-
shape: Union[List[int], Tuple[int]],
135+
shape: List[int],
153136
reduciable: bool = False,
154137
differentiable: bool = False,
138+
atomic: bool = True,
155139
):
156-
# fitting output must be atomic
157-
super().__init__(name, shape, atomic=True)
140+
self.name = name
141+
self.shape = list(shape)
142+
self.atomic = atomic
158143
self.reduciable = reduciable
159144
self.differentiable = differentiable
160145
if not self.reduciable and self.differentiable:
@@ -176,13 +161,13 @@ class FittingOutputDef:
176161

177162
def __init__(
178163
self,
179-
var_defs: List[OutputVariableDef] = [],
164+
var_defs: List[OutputVariableDef],
180165
):
181166
self.var_defs = {vv.name: vv for vv in var_defs}
182167

183168
def __getitem__(
184169
self,
185-
key,
170+
key: str,
186171
) -> OutputVariableDef:
187172
return self.var_defs[key]
188173

@@ -215,7 +200,7 @@ def __init__(
215200
self.def_outp = fit_defs
216201
self.def_redu = do_reduce(self.def_outp)
217202
self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp)
218-
self.var_defs = {}
203+
self.var_defs: Dict[str, OutputVariableDef] = {}
219204
for ii in [
220205
self.def_outp.get_data(),
221206
self.def_redu,
@@ -224,10 +209,16 @@ def __init__(
224209
]:
225210
self.var_defs.update(ii)
226211

227-
def __getitem__(self, key) -> VariableDef:
212+
def __getitem__(
213+
self,
214+
key: str,
215+
) -> OutputVariableDef:
228216
return self.var_defs[key]
229217

230-
def get_data(self, key) -> Dict[str, VariableDef]:
218+
def get_data(
219+
self,
220+
key: str,
221+
) -> Dict[str, OutputVariableDef]:
231222
return self.var_defs
232223

233224
def keys(self):
@@ -246,33 +237,45 @@ def keys_derv_c(self):
246237
return self.def_derv_c.keys()
247238

248239

249-
def get_reduce_name(name):
240+
def get_reduce_name(name: str) -> str:
250241
return name + "_redu"
251242

252243

253-
def get_deriv_name(name):
244+
def get_deriv_name(name: str) -> Tuple[str, str]:
254245
return name + "_derv_r", name + "_derv_c"
255246

256247

257248
def do_reduce(
258-
def_outp,
259-
):
260-
def_redu = {}
249+
def_outp: FittingOutputDef,
250+
) -> Dict[str, OutputVariableDef]:
251+
def_redu: Dict[str, OutputVariableDef] = {}
261252
for kk, vv in def_outp.get_data().items():
262253
if vv.reduciable:
263254
rk = get_reduce_name(kk)
264-
def_redu[rk] = VariableDef(rk, vv.shape, atomic=False)
255+
def_redu[rk] = OutputVariableDef(
256+
rk, vv.shape, reduciable=False, differentiable=False, atomic=False
257+
)
265258
return def_redu
266259

267260

268261
def do_derivative(
269-
def_outp,
270-
):
271-
def_derv_r = {}
272-
def_derv_c = {}
262+
def_outp: FittingOutputDef,
263+
) -> Tuple[Dict[str, OutputVariableDef], Dict[str, OutputVariableDef]]:
264+
def_derv_r: Dict[str, OutputVariableDef] = {}
265+
def_derv_c: Dict[str, OutputVariableDef] = {}
273266
for kk, vv in def_outp.get_data().items():
274267
if vv.differentiable:
275268
rkr, rkc = get_deriv_name(kk)
276-
def_derv_r[rkr] = VariableDef(rkr, [*vv.shape, 3], atomic=True)
277-
def_derv_c[rkc] = VariableDef(rkc, [*vv.shape, 3, 3], atomic=False)
269+
def_derv_r[rkr] = OutputVariableDef(
270+
rkr,
271+
vv.shape + [3], # noqa: RUF005
272+
reduciable=False,
273+
differentiable=False,
274+
)
275+
def_derv_c[rkc] = OutputVariableDef(
276+
rkc,
277+
vv.shape + [3, 3], # noqa: RUF005
278+
reduciable=True,
279+
differentiable=False,
280+
)
278281
return def_derv_r, def_derv_c

source/tests/test_output_def.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import unittest
3+
from typing import (
4+
List,
5+
)
36

47
import numpy as np
58

@@ -11,6 +14,21 @@
1114
fitting_check_output,
1215
model_check_output,
1316
)
17+
from deepmd_utils.model_format.output_def import (
18+
check_var,
19+
)
20+
21+
22+
class VariableDef:
23+
def __init__(
24+
self,
25+
name: str,
26+
shape: List[int],
27+
atomic: bool = True,
28+
):
29+
self.name = name
30+
self.shape = list(shape)
31+
self.atomic = atomic
1432

1533

1634
class TestDef(unittest.TestCase):
@@ -81,7 +99,7 @@ def test_model_output_def(self):
8199
self.assertEqual(md["foo"].atomic, True)
82100
self.assertEqual(md["energy_redu"].atomic, False)
83101
self.assertEqual(md["energy_derv_r"].atomic, True)
84-
self.assertEqual(md["energy_derv_c"].atomic, False)
102+
self.assertEqual(md["energy_derv_c"].atomic, True)
85103

86104
def test_raise_no_redu_deriv(self):
87105
with self.assertRaises(ValueError) as context:
@@ -90,6 +108,7 @@ def test_raise_no_redu_deriv(self):
90108
def test_model_decorator(self):
91109
nf = 2
92110
nloc = 3
111+
nall = 4
93112

94113
@model_check_output
95114
class Foo(NativeOP):
@@ -103,8 +122,8 @@ def call(self):
103122
return {
104123
"energy": np.zeros([nf, nloc, 1]),
105124
"energy_redu": np.zeros([nf, 1]),
106-
"energy_derv_r": np.zeros([nf, nloc, 1, 3]),
107-
"energy_derv_c": np.zeros([nf, 1, 3, 3]),
125+
"energy_derv_r": np.zeros([nf, nall, 1, 3]),
126+
"energy_derv_c": np.zeros([nf, nall, 1, 3, 3]),
108127
}
109128

110129
ff = Foo()
@@ -113,6 +132,7 @@ def call(self):
113132
def test_model_decorator_keyerror(self):
114133
nf = 2
115134
nloc = 3
135+
nall = 4
116136

117137
@model_check_output
118138
class Foo(NativeOP):
@@ -129,7 +149,7 @@ def call(self):
129149
return {
130150
"energy": np.zeros([nf, nloc, 1]),
131151
"energy_redu": np.zeros([nf, 1]),
132-
"energy_derv_c": np.zeros([nf, 1, 3, 3]),
152+
"energy_derv_c": np.zeros([nf, nall, 1, 3, 3]),
133153
}
134154

135155
ff = Foo()
@@ -140,13 +160,14 @@ def call(self):
140160
def test_model_decorator_shapeerror(self):
141161
nf = 2
142162
nloc = 3
163+
nall = 4
143164

144165
@model_check_output
145166
class Foo(NativeOP):
146167
def __init__(
147168
self,
148169
shape_rd=[nf, 1],
149-
shape_dr=[nf, nloc, 1, 3],
170+
shape_dr=[nf, nall, 1, 3],
150171
):
151172
self.shape_rd, self.shape_dr = shape_rd, shape_dr
152173

@@ -161,7 +182,7 @@ def call(self):
161182
"energy": np.zeros([nf, nloc, 1]),
162183
"energy_redu": np.zeros(self.shape_rd),
163184
"energy_derv_r": np.zeros(self.shape_dr),
164-
"energy_derv_c": np.zeros([nf, 1, 3, 3]),
185+
"energy_derv_c": np.zeros([nf, nall, 1, 3, 3]),
165186
}
166187

167188
ff = Foo()
@@ -192,6 +213,7 @@ def call(self):
192213
def test_fitting_decorator(self):
193214
nf = 2
194215
nloc = 3
216+
nall = 4
195217

196218
@fitting_check_output
197219
class Foo(NativeOP):
@@ -243,3 +265,40 @@ def call(self):
243265
ff = Foo(shape=[nf, nloc, 2])
244266
ff()
245267
self.assertIn("not matching", context.exception)
268+
269+
def test_check_var(self):
270+
var_def = VariableDef("foo", [2, 3], atomic=True)
271+
with self.assertRaises(ValueError) as context:
272+
check_var(np.zeros([2, 3, 4, 5, 6]), var_def)
273+
self.assertIn("length not matching", context.exception)
274+
with self.assertRaises(ValueError) as context:
275+
check_var(np.zeros([2, 3, 4, 5]), var_def)
276+
self.assertIn("shape not matching", context.exception)
277+
check_var(np.zeros([2, 3, 2, 3]), var_def)
278+
279+
var_def = VariableDef("foo", [2, 3], atomic=False)
280+
with self.assertRaises(ValueError) as context:
281+
check_var(np.zeros([2, 3, 4, 5]), var_def)
282+
self.assertIn("length not matching", context.exception)
283+
with self.assertRaises(ValueError) as context:
284+
check_var(np.zeros([2, 3, 4]), var_def)
285+
self.assertIn("shape not matching", context.exception)
286+
check_var(np.zeros([2, 2, 3]), var_def)
287+
288+
var_def = VariableDef("foo", [2, -1], atomic=True)
289+
with self.assertRaises(ValueError) as context:
290+
check_var(np.zeros([2, 3, 4, 5, 6]), var_def)
291+
self.assertIn("length not matching", context.exception)
292+
with self.assertRaises(ValueError) as context:
293+
check_var(np.zeros([2, 3, 4, 5]), var_def)
294+
self.assertIn("shape not matching", context.exception)
295+
check_var(np.zeros([2, 3, 2, 8]), var_def)
296+
297+
var_def = VariableDef("foo", [2, -1], atomic=False)
298+
with self.assertRaises(ValueError) as context:
299+
check_var(np.zeros([2, 3, 4, 5]), var_def)
300+
self.assertIn("length not matching", context.exception)
301+
with self.assertRaises(ValueError) as context:
302+
check_var(np.zeros([2, 3, 4]), var_def)
303+
self.assertIn("shape not matching", context.exception)
304+
check_var(np.zeros([2, 2, 8]), var_def)

0 commit comments

Comments
 (0)