Skip to content

Commit 43f9639

Browse files
authored
add native Networks for mutiple Network classes (#3117)
Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 398f037 commit 43f9639

File tree

4 files changed

+194
-12
lines changed

4 files changed

+194
-12
lines changed

deepmd_utils/model_format/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
EmbeddingNet,
1111
NativeLayer,
1212
NativeNet,
13+
NetworkCollection,
1314
load_dp_model,
1415
save_dp_model,
1516
traverse_model_dict,
@@ -24,6 +25,7 @@
2425
"EmbeddingNet",
2526
"NativeLayer",
2627
"NativeNet",
28+
"NetworkCollection",
2729
"load_dp_model",
2830
"save_dp_model",
2931
"traverse_model_dict",

deepmd_utils/model_format/network.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
44
See issue #2982 for more information.
55
"""
6+
import itertools
67
import json
78
from typing import (
9+
ClassVar,
10+
Dict,
811
List,
912
Optional,
13+
Union,
1014
)
1115

1216
import h5py
@@ -411,3 +415,111 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
411415
obj = cls(**data)
412416
super(EmbeddingNet, obj).__init__(layers)
413417
return obj
418+
419+
420+
class NetworkCollection:
421+
"""A collection of networks for multiple elements.
422+
423+
The number of dimesions for types might be 0, 1, or 2.
424+
- 0: embedding or fitting with type embedding, in ()
425+
- 1: embedding with type_one_side, or fitting, in (type_i)
426+
- 2: embedding without type_one_side, in (type_i, type_j)
427+
428+
Parameters
429+
----------
430+
ndim : int
431+
The number of dimensions.
432+
network_type : str, optional
433+
The type of the network.
434+
networks : dict, optional
435+
The networks to initialize with.
436+
"""
437+
438+
# subclass may override this
439+
NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = {
440+
"network": NativeNet,
441+
"embedding_network": EmbeddingNet,
442+
}
443+
444+
def __init__(
445+
self,
446+
ndim: int,
447+
ntypes: int,
448+
network_type: str = "network",
449+
networks: List[Union[NativeNet, dict]] = [],
450+
):
451+
self.ndim = ndim
452+
self.ntypes = ntypes
453+
self.network_type = self.NETWORK_TYPE_MAP[network_type]
454+
self._networks = [None for ii in range(ntypes**ndim)]
455+
for ii, network in enumerate(networks):
456+
self[ii] = network
457+
if len(networks):
458+
self.check_completeness()
459+
460+
def check_completeness(self):
461+
"""Check whether the collection is complete.
462+
463+
Raises
464+
------
465+
RuntimeError
466+
If the collection is incomplete.
467+
"""
468+
for tt in itertools.product(range(self.ntypes), repeat=self.ndim):
469+
if self[tuple(tt)] is None:
470+
raise RuntimeError(f"network for {tt} not found")
471+
472+
def _convert_key(self, key):
473+
if isinstance(key, int):
474+
idx = key
475+
else:
476+
if isinstance(key, tuple):
477+
pass
478+
elif isinstance(key, str):
479+
key = tuple([int(tt) for tt in key.split("_")[1:]])
480+
else:
481+
raise TypeError(key)
482+
assert isinstance(key, tuple)
483+
assert len(key) == self.ndim
484+
idx = sum([tt * self.ntypes**ii for ii, tt in enumerate(key)])
485+
return idx
486+
487+
def __getitem__(self, key):
488+
return self._networks[self._convert_key(key)]
489+
490+
def __setitem__(self, key, value):
491+
if isinstance(value, self.network_type):
492+
pass
493+
elif isinstance(value, dict):
494+
value = self.network_type.deserialize(value)
495+
else:
496+
raise TypeError(value)
497+
self._networks[self._convert_key(key)] = value
498+
499+
def serialize(self) -> dict:
500+
"""Serialize the networks to a dict.
501+
502+
Returns
503+
-------
504+
dict
505+
The serialized networks.
506+
"""
507+
network_type_map_inv = {v: k for k, v in self.NETWORK_TYPE_MAP.items()}
508+
network_type_name = network_type_map_inv[self.network_type]
509+
return {
510+
"ndim": self.ndim,
511+
"ntypes": self.ntypes,
512+
"network_type": network_type_name,
513+
"networks": [nn.serialize() for nn in self._networks],
514+
}
515+
516+
@classmethod
517+
def deserialize(cls, data: dict) -> "NetworkCollection":
518+
"""Deserialize the networks from a dict.
519+
520+
Parameters
521+
----------
522+
data : dict
523+
The dict to deserialize from.
524+
"""
525+
return cls(**data)

deepmd_utils/model_format/se_e2_a.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from .network import (
2323
EmbeddingNet,
24+
NetworkCollection,
2425
)
2526

2627

@@ -154,16 +155,18 @@ def __init__(
154155
self.spin = spin
155156

156157
in_dim = 1 # not considiering type embedding
157-
self.embeddings = []
158+
self.embeddings = NetworkCollection(
159+
ntypes=self.ntypes,
160+
ndim=(1 if self.type_one_side else 2),
161+
network_type="embedding_network",
162+
)
158163
for ii in range(self.ntypes):
159-
self.embeddings.append(
160-
EmbeddingNet(
161-
in_dim,
162-
self.neuron,
163-
self.activation_function,
164-
self.resnet_dt,
165-
self.precision,
166-
)
164+
self.embeddings[(ii,)] = EmbeddingNet(
165+
in_dim,
166+
self.neuron,
167+
self.activation_function,
168+
self.resnet_dt,
169+
self.precision,
167170
)
168171
self.env_mat = EnvMat(self.rcut, self.rcut_smth)
169172
self.nnei = np.sum(self.sel)
@@ -196,7 +199,7 @@ def cal_g(
196199
nf, nloc, nnei = ss.shape[0:3]
197200
ss = ss.reshape(nf, nloc, nnei, 1)
198201
# nf x nloc x nnei x ng
199-
gg = self.embeddings[ll].call(ss)
202+
gg = self.embeddings[(ll,)].call(ss)
200203
return gg
201204

202205
def call(
@@ -258,7 +261,7 @@ def serialize(self) -> dict:
258261
"precision": self.precision,
259262
"spin": self.spin,
260263
"env_mat": self.env_mat.serialize(),
261-
"embeddings": [ii.serialize() for ii in self.embeddings],
264+
"embeddings": self.embeddings.serialize(),
262265
"@variables": {
263266
"davg": self.davg,
264267
"dstd": self.dstd,
@@ -274,6 +277,6 @@ def deserialize(cls, data: dict) -> "DescrptSeA":
274277

275278
obj["davg"] = variables["davg"]
276279
obj["dstd"] = variables["dstd"]
277-
obj.embeddings = [EmbeddingNet.deserialize(dd) for dd in embeddings]
280+
obj.embeddings = NetworkCollection.deserialize(embeddings)
278281
obj.env_mat = EnvMat.deserialize(env_mat)
279282
return obj

source/tests/test_model_format_utils.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
EnvMat,
1515
NativeLayer,
1616
NativeNet,
17+
NetworkCollection,
1718
load_dp_model,
1819
save_dp_model,
1920
)
@@ -115,6 +116,70 @@ def test_embedding_net(self):
115116
np.testing.assert_allclose(en0.call(inp), en1.call(inp))
116117

117118

119+
class TestNetworkCollection(unittest.TestCase):
120+
def setUp(self) -> None:
121+
w = np.full((2, 3), 3.0)
122+
b = np.full((3,), 4.0)
123+
self.network = {
124+
"layers": [
125+
{
126+
"activation_function": "tanh",
127+
"resnet": True,
128+
"@variables": {"w": w, "b": b},
129+
},
130+
{
131+
"activation_function": "tanh",
132+
"resnet": True,
133+
"@variables": {"w": w, "b": b},
134+
},
135+
],
136+
}
137+
138+
def test_two_dim(self):
139+
networks = NetworkCollection(ndim=2, ntypes=2)
140+
networks[(0, 0)] = self.network
141+
networks[(1, 1)] = self.network
142+
networks[(0, 1)] = self.network
143+
with self.assertRaises(RuntimeError):
144+
networks.check_completeness()
145+
networks[(1, 0)] = self.network
146+
networks.check_completeness()
147+
np.testing.assert_equal(
148+
networks.serialize(),
149+
NetworkCollection.deserialize(networks.serialize()).serialize(),
150+
)
151+
np.testing.assert_equal(
152+
networks[(0, 0)].serialize(), networks.serialize()["networks"][0]
153+
)
154+
155+
def test_one_dim(self):
156+
networks = NetworkCollection(ndim=1, ntypes=2)
157+
networks[(0,)] = self.network
158+
with self.assertRaises(RuntimeError):
159+
networks.check_completeness()
160+
networks[(1,)] = self.network
161+
networks.check_completeness()
162+
np.testing.assert_equal(
163+
networks.serialize(),
164+
NetworkCollection.deserialize(networks.serialize()).serialize(),
165+
)
166+
np.testing.assert_equal(
167+
networks[(0,)].serialize(), networks.serialize()["networks"][0]
168+
)
169+
170+
def test_zero_dim(self):
171+
networks = NetworkCollection(ndim=0, ntypes=2)
172+
networks[()] = self.network
173+
networks.check_completeness()
174+
np.testing.assert_equal(
175+
networks.serialize(),
176+
NetworkCollection.deserialize(networks.serialize()).serialize(),
177+
)
178+
np.testing.assert_equal(
179+
networks[()].serialize(), networks.serialize()["networks"][0]
180+
)
181+
182+
118183
class TestDPModel(unittest.TestCase):
119184
def setUp(self) -> None:
120185
self.w = np.full((3, 2), 3.0)

0 commit comments

Comments
 (0)