Skip to content

Commit 438bc78

Browse files
wanghan-iapcmHan Wang
andauthored
Add dp model format sea (#3123)
- add precision test for embedding net Limitations - only support `type_one_side` - does not support type embedding and `stripped_type_embedding` - does not support `exclude_types` - does not support spin --------- Co-authored-by: Han Wang <[email protected]>
1 parent a971d92 commit 438bc78

File tree

4 files changed

+243
-5
lines changed

4 files changed

+243
-5
lines changed

deepmd_utils/model_format/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
from .common import (
3+
DEFAULT_PRECISION,
34
PRECISION_DICT,
45
)
56
from .env_mat import (
@@ -13,8 +14,12 @@
1314
save_dp_model,
1415
traverse_model_dict,
1516
)
17+
from .se_e2_a import (
18+
DescrptSeA,
19+
)
1620

1721
__all__ = [
22+
"DescrptSeA",
1823
"EnvMat",
1924
"EmbeddingNet",
2025
"NativeLayer",
@@ -23,4 +28,5 @@
2328
"save_dp_model",
2429
"traverse_model_dict",
2530
"PRECISION_DICT",
31+
"DEFAULT_PRECISION",
2632
]

deepmd_utils/model_format/network.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ def __init__(
379379
self.neuron = neuron
380380
self.activation_function = activation_function
381381
self.resnet_dt = resnet_dt
382+
self.precision = precision
382383

383384
def serialize(self) -> dict:
384385
"""Serialize the network to a dict.
@@ -393,6 +394,7 @@ def serialize(self) -> dict:
393394
"neuron": self.neuron.copy(),
394395
"activation_function": self.activation_function,
395396
"resnet_dt": self.resnet_dt,
397+
"precision": self.precision,
396398
"layers": [layer.serialize() for layer in self.layers],
397399
}
398400

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import numpy as np
3+
4+
try:
5+
from deepmd_utils._version import version as __version__
6+
except ImportError:
7+
__version__ = "unknown"
8+
9+
from typing import (
10+
Any,
11+
List,
12+
Optional,
13+
)
14+
15+
from .common import (
16+
DEFAULT_PRECISION,
17+
NativeOP,
18+
)
19+
from .env_mat import (
20+
EnvMat,
21+
)
22+
from .network import (
23+
EmbeddingNet,
24+
)
25+
26+
27+
class DescrptSeA(NativeOP):
28+
def __init__(
29+
self,
30+
rcut: float,
31+
rcut_smth: float,
32+
sel: List[str],
33+
neuron: List[int] = [24, 48, 96],
34+
axis_neuron: int = 8,
35+
resnet_dt: bool = False,
36+
trainable: bool = True,
37+
type_one_side: bool = True,
38+
exclude_types: List[List[int]] = [],
39+
set_davg_zero: bool = False,
40+
activation_function: str = "tanh",
41+
precision: str = DEFAULT_PRECISION,
42+
spin: Optional[Any] = None,
43+
stripped_type_embedding: bool = False,
44+
) -> None:
45+
## seed, uniform_seed, multi_task, not included.
46+
if not type_one_side:
47+
raise NotImplementedError("type_one_side == False not implemented")
48+
if stripped_type_embedding:
49+
raise NotImplementedError("stripped_type_embedding is not implemented")
50+
if exclude_types != []:
51+
raise NotImplementedError("exclude_types is not implemented")
52+
if spin is not None:
53+
raise NotImplementedError("spin is not implemented")
54+
55+
self.rcut = rcut
56+
self.rcut_smth = rcut_smth
57+
self.sel = sel
58+
self.ntypes = len(self.sel)
59+
self.neuron = neuron
60+
self.axis_neuron = axis_neuron
61+
self.resnet_dt = resnet_dt
62+
self.trainable = trainable
63+
self.type_one_side = type_one_side
64+
self.exclude_types = exclude_types
65+
self.set_davg_zero = set_davg_zero
66+
self.activation_function = activation_function
67+
self.precision = precision
68+
self.spin = spin
69+
self.stripped_type_embedding = stripped_type_embedding
70+
71+
in_dim = 1 # not considiering type embedding
72+
self.embeddings = []
73+
for ii in range(self.ntypes):
74+
self.embeddings.append(
75+
EmbeddingNet(
76+
in_dim,
77+
self.neuron,
78+
self.activation_function,
79+
self.resnet_dt,
80+
self.precision,
81+
)
82+
)
83+
self.env_mat = EnvMat(self.rcut, self.rcut_smth)
84+
self.nnei = np.sum(self.sel)
85+
self.nneix4 = self.nnei * 4
86+
self.davg = np.zeros([self.ntypes, self.nneix4])
87+
self.dstd = np.ones([self.ntypes, self.nneix4])
88+
self.orig_sel = self.sel
89+
90+
def __setitem__(self, key, value):
91+
if key in ("avg", "data_avg", "davg"):
92+
self.davg = value
93+
elif key in ("std", "data_std", "dstd"):
94+
self.dstd = value
95+
else:
96+
raise KeyError(key)
97+
98+
def __getitem__(self, key):
99+
if key in ("avg", "data_avg", "davg"):
100+
return self.davg
101+
elif key in ("std", "data_std", "dstd"):
102+
return self.dstd
103+
else:
104+
raise KeyError(key)
105+
106+
def cal_g(
107+
self,
108+
ss,
109+
ll,
110+
):
111+
nf, nloc, nnei = ss.shape[0:3]
112+
ss = ss.reshape(nf, nloc, nnei, 1)
113+
# nf x nloc x nnei x ng
114+
gg = self.embeddings[ll].call(ss)
115+
return gg
116+
117+
def call(
118+
self,
119+
coord_ext,
120+
atype_ext,
121+
nlist,
122+
):
123+
"""Compute the environment matrix.
124+
125+
Parameters
126+
----------
127+
coord_ext
128+
The extended coordinates of atoms. shape: nf x (nallx3)
129+
atype_ext
130+
The extended aotm types. shape: nf x nall
131+
nlist
132+
The neighbor list. shape: nf x nloc x nnei
133+
134+
Returns
135+
-------
136+
descriptor
137+
The descriptor. shape: nf x nloc x ng x axis_neuron
138+
"""
139+
# nf x nloc x nnei x 4
140+
rr, ww = self.env_mat.call(nlist, coord_ext, atype_ext, self.davg, self.dstd)
141+
nf, nloc, nnei, _ = rr.shape
142+
sec = np.append([0], np.cumsum(self.sel))
143+
144+
ng = self.neuron[-1]
145+
gr = np.zeros([nf, nloc, ng, 4])
146+
for tt in range(self.ntypes):
147+
tr = rr[:, :, sec[tt] : sec[tt + 1], :]
148+
ss = tr[..., 0:1]
149+
gg = self.cal_g(ss, tt)
150+
# nf x nloc x ng x 4
151+
gr += np.einsum("flni,flnj->flij", gg, tr)
152+
gr /= self.nnei
153+
gr1 = gr[:, :, : self.axis_neuron, :]
154+
# nf x nloc x ng x ng1
155+
grrg = np.einsum("flid,fljd->flij", gr, gr1)
156+
# nf x nloc x (ng x ng1)
157+
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
158+
return grrg
159+
160+
def serialize(self) -> dict:
161+
return {
162+
"rcut": self.rcut,
163+
"rcut_smth": self.rcut_smth,
164+
"sel": self.sel,
165+
"neuron": self.neuron,
166+
"axis_neuron": self.axis_neuron,
167+
"resnet_dt": self.resnet_dt,
168+
"trainable": self.trainable,
169+
"type_one_side": self.type_one_side,
170+
"exclude_types": self.exclude_types,
171+
"set_davg_zero": self.set_davg_zero,
172+
"activation_function": self.activation_function,
173+
"precision": self.precision,
174+
"spin": self.spin,
175+
"stripped_type_embedding": self.stripped_type_embedding,
176+
"env_mat": self.env_mat.serialize(),
177+
"embeddings": [ii.serialize() for ii in self.embeddings],
178+
"@variables": {
179+
"davg": self.davg,
180+
"dstd": self.dstd,
181+
},
182+
}
183+
184+
@classmethod
185+
def deserialize(cls, data: dict) -> "DescrptSeA":
186+
variables = data.pop("@variables")
187+
embeddings = data.pop("embeddings")
188+
env_mat = data.pop("env_mat")
189+
obj = cls(**data)
190+
191+
obj["davg"] = variables["davg"]
192+
obj["dstd"] = variables["dstd"]
193+
obj.embeddings = [EmbeddingNet.deserialize(dd) for dd in embeddings]
194+
obj.env_mat = EnvMat.deserialize(env_mat)
195+
return obj

source/tests/test_model_format_utils.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010

1111
from deepmd_utils.model_format import (
12+
DescrptSeA,
1213
EmbeddingNet,
1314
EnvMat,
1415
NativeLayer,
@@ -97,12 +98,18 @@ def test_deserialize(self):
9798
np.testing.assert_array_equal(network[1]["resnet"], True)
9899

99100
def test_embedding_net(self):
100-
for ni, idt, act in itertools.product(
101+
for ni, act, idt, prec in itertools.product(
101102
[1, 10],
102-
[True, False],
103103
["tanh", "none"],
104+
[True, False],
105+
["double", "single"],
104106
):
105-
en0 = EmbeddingNet(ni)
107+
en0 = EmbeddingNet(
108+
ni,
109+
activation_function=act,
110+
precision=prec,
111+
resnet_dt=idt,
112+
)
106113
en1 = EmbeddingNet.deserialize(en0.serialize())
107114
inp = np.ones([ni])
108115
np.testing.assert_allclose(en0.call(inp), en1.call(inp))
@@ -141,7 +148,7 @@ def tearDown(self) -> None:
141148
os.remove(self.filename)
142149

143150

144-
class TestEnvMat(unittest.TestCase):
151+
class TestCaseSingleFrameWithNlist:
145152
def setUp(self):
146153
# nloc == 3, nall == 4
147154
self.nloc = 3
@@ -158,17 +165,23 @@ def setUp(self):
158165
).reshape([1, self.nall * 3])
159166
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
160167
# sel = [5, 2]
168+
self.sel = [5, 2]
161169
self.nlist = np.array(
162170
[
163171
[1, 3, -1, -1, -1, 2, -1],
164172
[0, -1, -1, -1, -1, 2, -1],
165173
[0, 1, -1, -1, -1, 0, -1],
166174
],
167175
dtype=int,
168-
).reshape([1, self.nloc, 7])
176+
).reshape([1, self.nloc, sum(self.sel)])
169177
self.rcut = 0.4
170178
self.rcut_smth = 2.2
171179

180+
181+
class TestEnvMat(unittest.TestCase, TestCaseSingleFrameWithNlist):
182+
def setUp(self):
183+
TestCaseSingleFrameWithNlist.setUp(self)
184+
172185
def test_self_consistency(
173186
self,
174187
):
@@ -183,3 +196,25 @@ def test_self_consistency(
183196
mm1, ww1 = em1.call(self.nlist, self.coord_ext, self.atype_ext, davg, dstd)
184197
np.testing.assert_allclose(mm0, mm1)
185198
np.testing.assert_allclose(ww0, ww1)
199+
200+
201+
class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist):
202+
def setUp(self):
203+
TestCaseSingleFrameWithNlist.setUp(self)
204+
205+
def test_self_consistency(
206+
self,
207+
):
208+
rng = np.random.default_rng()
209+
nf, nloc, nnei = self.nlist.shape
210+
davg = rng.normal(size=(self.nt, nnei, 4))
211+
dstd = rng.normal(size=(self.nt, nnei, 4))
212+
dstd = 0.1 + np.abs(dstd)
213+
214+
em0 = DescrptSeA(self.rcut, self.rcut_smth, self.sel)
215+
em0.davg = davg
216+
em0.dstd = dstd
217+
em1 = DescrptSeA.deserialize(em0.serialize())
218+
mm0 = em0.call(self.coord_ext, self.atype_ext, self.nlist)
219+
mm1 = em1.call(self.coord_ext, self.atype_ext, self.nlist)
220+
np.testing.assert_allclose(mm0, mm1)

0 commit comments

Comments
 (0)