Skip to content

Commit ff58456

Browse files
wanghan-iapcmHan WangCodiumAI-Agentpre-commit-ci[bot]
authored
add definition for the output of fitting and model (#3128)
Signed-off-by: Han Wang <[email protected]> Co-authored-by: Han Wang <[email protected]> Co-authored-by: CodiumAI-Agent <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 43f9639 commit ff58456

File tree

3 files changed

+533
-0
lines changed

3 files changed

+533
-0
lines changed

deepmd_utils/model_format/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
save_dp_model,
1616
traverse_model_dict,
1717
)
18+
from .output_def import (
19+
FittingOutputDef,
20+
ModelOutputDef,
21+
OutputVariableDef,
22+
VariableDef,
23+
fitting_check_output,
24+
model_check_output,
25+
)
1826
from .se_e2_a import (
1927
DescrptSeA,
2028
)
@@ -31,4 +39,10 @@
3139
"traverse_model_dict",
3240
"PRECISION_DICT",
3341
"DEFAULT_PRECISION",
42+
"ModelOutputDef",
43+
"FittingOutputDef",
44+
"OutputVariableDef",
45+
"VariableDef",
46+
"model_check_output",
47+
"fitting_check_output",
3448
]
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Dict,
4+
List,
5+
Tuple,
6+
Union,
7+
)
8+
9+
10+
def check_var(var, var_def):
11+
if var_def.atomic:
12+
# var.shape == [nf, nloc, *var_def.shape]
13+
if len(var.shape) != len(var_def.shape) + 2:
14+
raise ValueError(f"{var.shape[2:]} length not matching def {var_def.shape}")
15+
if list(var.shape[2:]) != var_def.shape:
16+
raise ValueError(f"{var.shape[2:]} not matching def {var_def.shape}")
17+
else:
18+
# var.shape == [nf, *var_def.shape]
19+
if len(var.shape) != len(var_def.shape) + 1:
20+
raise ValueError(f"{var.shape[1:]} length not matching def {var_def.shape}")
21+
if list(var.shape[1:]) != var_def.shape:
22+
raise ValueError(f"{var.shape[1:]} not matching def {var_def.shape}")
23+
24+
25+
def model_check_output(cls):
26+
"""Check if the output of the Model is consistent with the definition.
27+
28+
Two methods are assumed to be provided by the Model:
29+
1. Model.output_def that gives the output definition.
30+
2. Model.forward that defines the forward path of the model.
31+
32+
"""
33+
34+
class wrapper(cls):
35+
def __init__(
36+
self,
37+
*args,
38+
**kwargs,
39+
):
40+
super().__init__(*args, **kwargs)
41+
self.md = cls.output_def(self)
42+
43+
def forward(
44+
self,
45+
*args,
46+
**kwargs,
47+
):
48+
ret = cls.forward(self, *args, **kwargs)
49+
for kk in self.md.keys_outp():
50+
dd = self.md[kk]
51+
check_var(ret[kk], dd)
52+
if dd.reduciable:
53+
rk = get_reduce_name(kk)
54+
check_var(ret[rk], self.md[rk])
55+
if dd.differentiable:
56+
dnr, dnc = get_deriv_name(kk)
57+
check_var(ret[dnr], self.md[dnr])
58+
check_var(ret[dnc], self.md[dnc])
59+
return ret
60+
61+
return wrapper
62+
63+
64+
def fitting_check_output(cls):
65+
"""Check if the output of the Fitting is consistent with the definition.
66+
67+
Two methods are assumed to be provided by the Fitting:
68+
1. Fitting.output_def that gives the output definition.
69+
2. Fitting.forward defines the forward path of the fitting.
70+
71+
"""
72+
73+
class wrapper(cls):
74+
def __init__(
75+
self,
76+
*args,
77+
**kwargs,
78+
):
79+
super().__init__(*args, **kwargs)
80+
self.md = cls.output_def(self)
81+
82+
def forward(
83+
self,
84+
*args,
85+
**kwargs,
86+
):
87+
ret = cls.forward(self, *args, **kwargs)
88+
for kk in self.md.keys():
89+
dd = self.md[kk]
90+
check_var(ret[kk], dd)
91+
return ret
92+
93+
return wrapper
94+
95+
96+
class VariableDef:
97+
"""Defines the shape and other properties of a variable.
98+
99+
Parameters
100+
----------
101+
name
102+
Name of the output variable. Notice that the xxxx_redu,
103+
xxxx_derv_c, xxxx_derv_r are reserved names that should
104+
not be used to define variables.
105+
shape
106+
The shape of the variable. e.g. energy should be [1],
107+
dipole should be [3], polarizabilty should be [3,3].
108+
atomic
109+
If the variable is defined for each atom.
110+
111+
"""
112+
113+
def __init__(
114+
self,
115+
name: str,
116+
shape: Union[List[int], Tuple[int]],
117+
atomic: bool = True,
118+
):
119+
self.name = name
120+
self.shape = list(shape)
121+
self.atomic = atomic
122+
123+
124+
class OutputVariableDef(VariableDef):
125+
"""Defines the shape and other properties of the one output variable.
126+
127+
It is assume that the fitting network output variables for each
128+
local atom. This class defines one output variable, including its
129+
name, shape, reducibility and differentiability.
130+
131+
Parameters
132+
----------
133+
name
134+
Name of the output variable. Notice that the xxxx_redu,
135+
xxxx_derv_c, xxxx_derv_r are reserved names that should
136+
not be used to define variables.
137+
shape
138+
The shape of the variable. e.g. energy should be [1],
139+
dipole should be [3], polarizabilty should be [3,3].
140+
reduciable
141+
If the variable is reduced.
142+
differentiable
143+
If the variable is differentiated with respect to coordinates
144+
of atoms and cell tensor (pbc case). Only reduciable variable
145+
are differentiable.
146+
147+
"""
148+
149+
def __init__(
150+
self,
151+
name: str,
152+
shape: Union[List[int], Tuple[int]],
153+
reduciable: bool = False,
154+
differentiable: bool = False,
155+
):
156+
# fitting output must be atomic
157+
super().__init__(name, shape, atomic=True)
158+
self.reduciable = reduciable
159+
self.differentiable = differentiable
160+
if not self.reduciable and self.differentiable:
161+
raise ValueError("only reduciable variable are differentiable")
162+
163+
164+
class FittingOutputDef:
165+
"""Defines the shapes and other properties of the fitting network outputs.
166+
167+
It is assume that the fitting network output variables for each
168+
local atom. This class defines all the outputs.
169+
170+
Parameters
171+
----------
172+
var_defs
173+
List of output variable definitions.
174+
175+
"""
176+
177+
def __init__(
178+
self,
179+
var_defs: List[OutputVariableDef] = [],
180+
):
181+
self.var_defs = {vv.name: vv for vv in var_defs}
182+
183+
def __getitem__(
184+
self,
185+
key,
186+
) -> OutputVariableDef:
187+
return self.var_defs[key]
188+
189+
def get_data(self) -> Dict[str, OutputVariableDef]:
190+
return self.var_defs
191+
192+
def keys(self):
193+
return self.var_defs.keys()
194+
195+
196+
class ModelOutputDef:
197+
"""Defines the shapes and other properties of the model outputs.
198+
199+
The model reduce and differentiate fitting outputs if applicable.
200+
If a variable is named by foo, then the reduced variable is called
201+
foo_redu, the derivative w.r.t. coordinates is called foo_derv_r
202+
and the derivative w.r.t. cell is called foo_derv_c.
203+
204+
Parameters
205+
----------
206+
fit_defs
207+
Definition for the fitting net output
208+
209+
"""
210+
211+
def __init__(
212+
self,
213+
fit_defs: FittingOutputDef,
214+
):
215+
self.def_outp = fit_defs
216+
self.def_redu = do_reduce(self.def_outp)
217+
self.def_derv_r, self.def_derv_c = do_derivative(self.def_outp)
218+
self.var_defs = {}
219+
for ii in [
220+
self.def_outp.get_data(),
221+
self.def_redu,
222+
self.def_derv_c,
223+
self.def_derv_r,
224+
]:
225+
self.var_defs.update(ii)
226+
227+
def __getitem__(self, key) -> VariableDef:
228+
return self.var_defs[key]
229+
230+
def get_data(self, key) -> Dict[str, VariableDef]:
231+
return self.var_defs
232+
233+
def keys(self):
234+
return self.var_defs.keys()
235+
236+
def keys_outp(self):
237+
return self.def_outp.keys()
238+
239+
def keys_redu(self):
240+
return self.def_redu.keys()
241+
242+
def keys_derv_r(self):
243+
return self.def_derv_r.keys()
244+
245+
def keys_derv_c(self):
246+
return self.def_derv_c.keys()
247+
248+
249+
def get_reduce_name(name):
250+
return name + "_redu"
251+
252+
253+
def get_deriv_name(name):
254+
return name + "_derv_r", name + "_derv_c"
255+
256+
257+
def do_reduce(
258+
def_outp,
259+
):
260+
def_redu = {}
261+
for kk, vv in def_outp.get_data().items():
262+
if vv.reduciable:
263+
rk = get_reduce_name(kk)
264+
def_redu[rk] = VariableDef(rk, vv.shape, atomic=False)
265+
return def_redu
266+
267+
268+
def do_derivative(
269+
def_outp,
270+
):
271+
def_derv_r = {}
272+
def_derv_c = {}
273+
for kk, vv in def_outp.get_data().items():
274+
if vv.differentiable:
275+
rkr, rkc = get_deriv_name(kk)
276+
def_derv_r[rkr] = VariableDef(rkr, [*vv.shape, 3], atomic=True)
277+
def_derv_c[rkc] = VariableDef(rkc, [*vv.shape, 3, 3], atomic=False)
278+
return def_derv_r, def_derv_c

0 commit comments

Comments
 (0)