Skip to content

Commit 15117a0

Browse files
wanghan-iapcmHan Wang
andauthored
refactorize NativeLayer, interface does not rely on the platform (#3138)
- add parameter shape consistency check for layer - add input-output shape consistency check for net Co-authored-by: Han Wang <[email protected]>
1 parent 308f97e commit 15117a0

File tree

3 files changed

+154
-53
lines changed

3 files changed

+154
-53
lines changed

deepmd_utils/model_format/network.py

Lines changed: 68 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
44
See issue #2982 for more information.
55
"""
6+
import copy
67
import itertools
78
import json
89
from typing import (
@@ -150,23 +151,26 @@ class NativeLayer(NativeOP):
150151

151152
def __init__(
152153
self,
153-
w: Optional[np.ndarray] = None,
154-
b: Optional[np.ndarray] = None,
155-
idt: Optional[np.ndarray] = None,
154+
num_in,
155+
num_out,
156+
bias: bool = True,
157+
use_timestep: bool = False,
156158
activation_function: Optional[str] = None,
157159
resnet: bool = False,
158160
precision: str = DEFAULT_PRECISION,
159161
) -> None:
160162
prec = PRECISION_DICT[precision.lower()]
161163
self.precision = precision
162-
self.w = w.astype(prec) if w is not None else None
163-
self.b = b.astype(prec) if b is not None else None
164-
self.idt = idt.astype(prec) if idt is not None else None
164+
rng = np.random.default_rng()
165+
self.w = rng.normal(size=(num_in, num_out)).astype(prec)
166+
self.b = rng.normal(size=(num_out,)).astype(prec) if bias else None
167+
self.idt = rng.normal(size=(num_out,)).astype(prec) if use_timestep else None
165168
self.activation_function = (
166169
activation_function if activation_function is not None else "none"
167170
)
168171
self.resnet = resnet
169172
self.check_type_consistency()
173+
self.check_shape_consistency()
170174

171175
def serialize(self) -> dict:
172176
"""Serialize the layer to a dict.
@@ -179,10 +183,11 @@ def serialize(self) -> dict:
179183
data = {
180184
"w": self.w,
181185
"b": self.b,
186+
"idt": self.idt,
182187
}
183-
if self.idt is not None:
184-
data["idt"] = self.idt
185188
return {
189+
"bias": self.b is not None,
190+
"use_timestep": self.idt is not None,
186191
"activation_function": self.activation_function,
187192
"resnet": self.resnet,
188193
"precision": self.precision,
@@ -198,15 +203,34 @@ def deserialize(cls, data: dict) -> "NativeLayer":
198203
data : dict
199204
The dict to deserialize from.
200205
"""
201-
precision = data.get("precision", DEFAULT_PRECISION)
202-
return cls(
203-
w=data["@variables"]["w"],
204-
b=data["@variables"].get("b", None),
205-
idt=data["@variables"].get("idt", None),
206-
activation_function=data["activation_function"],
207-
resnet=data.get("resnet", False),
208-
precision=precision,
206+
data = copy.deepcopy(data)
207+
variables = data.pop("@variables")
208+
assert variables["w"] is not None and len(variables["w"].shape) == 2
209+
num_in, num_out = variables["w"].shape
210+
obj = cls(
211+
num_in,
212+
num_out,
213+
**data,
209214
)
215+
obj.w, obj.b, obj.idt = (
216+
variables["w"],
217+
variables.get("b", None),
218+
variables.get("idt", None),
219+
)
220+
obj.check_shape_consistency()
221+
return obj
222+
223+
def check_shape_consistency(self):
    """Check that the optional vectors agree with the weight matrix.

    When ``b`` is set, its leading dimension must equal dim 1 of ``w``;
    the same holds for ``idt`` when it is set. Vectors that are ``None``
    are not checked.

    Raises
    ------
    ValueError
        If ``b`` or ``idt`` is present and its leading dimension differs
        from ``w.shape[1]``.
    """
    bias = self.b
    if bias is not None:
        if self.w.shape[1] != bias.shape[0]:
            raise ValueError(
                f"dim 1 of w {self.w.shape[1]} is not equal to shape "
                f"of b {bias.shape[0]}",
            )
    timestep = self.idt
    if timestep is not None:
        if self.w.shape[1] != timestep.shape[0]:
            raise ValueError(
                f"dim 1 of w {self.w.shape[1]} is not equal to shape "
                f"of idt {timestep.shape[0]}",
            )
210234

211235
def check_type_consistency(self):
212236
precision = self.precision
@@ -252,6 +276,14 @@ def __getitem__(self, key):
252276
else:
253277
raise KeyError(key)
254278

279+
@property
def dim_in(self) -> int:
    """Size of dim 0 of the weight matrix ``w``."""
    weight_shape = self.w.shape
    return weight_shape[0]
282+
283+
@property
def dim_out(self) -> int:
    """Size of dim 1 of the weight matrix ``w``."""
    weight_shape = self.w.shape
    return weight_shape[1]
286+
255287
def call(self, x: np.ndarray) -> np.ndarray:
256288
"""Forward pass.
257289
@@ -303,6 +335,7 @@ def __init__(self, layers: Optional[List[dict]] = None) -> None:
303335
if layers is None:
304336
layers = []
305337
self.layers = [NativeLayer.deserialize(layer) for layer in layers]
338+
self.check_shape_consistency()
306339

307340
def serialize(self) -> dict:
308341
"""Serialize the network to a dict.
@@ -327,16 +360,21 @@ def deserialize(cls, data: dict) -> "NativeNet":
327360

328361
def __getitem__(self, key):
329362
assert isinstance(key, int)
330-
if len(self.layers) <= key:
331-
self.layers.extend([NativeLayer()] * (key - len(self.layers) + 1))
332363
return self.layers[key]
333364

334365
def __setitem__(self, key, value):
335366
assert isinstance(key, int)
336-
if len(self.layers) <= key:
337-
self.layers.extend([NativeLayer()] * (key - len(self.layers) + 1))
338367
self.layers[key] = value
339368

369+
def check_shape_consistency(self):
    """Check that consecutive layers have matching dimensions.

    For every adjacent pair in ``self.layers``, the ``dim_out`` of the
    earlier layer must equal the ``dim_in`` of the later one. Networks
    with zero or one layer pass trivially.

    Raises
    ------
    ValueError
        If the output dimension of a layer differs from the input
        dimension of the layer that follows it.
    """
    for ii, (prev, nxt) in enumerate(zip(self.layers, self.layers[1:])):
        if prev.dim_out != nxt.dim_in:
            # Build ONE message string: separating the fragments with commas
            # would hand ValueError several positional args and render the
            # error as a tuple. Report the next layer's input dim, which is
            # the value that actually failed to match.
            raise ValueError(
                f"the dim of layer {ii} output {prev.dim_out} "
                f"does not match the dim of layer {ii + 1} "
                f"input {nxt.dim_in}"
            )
377+
340378
def call(self, x: np.ndarray) -> np.ndarray:
341379
"""Forward pass.
342380
@@ -389,9 +427,10 @@ def __init__(
389427
i_ot = ii
390428
layers.append(
391429
NativeLayer(
392-
rng.normal(size=(i_in, i_ot)),
393-
b=rng.normal(size=(i_ot)),
394-
idt=rng.normal(size=(i_ot)) if resnet_dt else None,
430+
i_in,
431+
i_ot,
432+
bias=True,
433+
use_timestep=resnet_dt,
395434
activation_function=activation_function,
396435
resnet=True,
397436
precision=precision,
@@ -431,6 +470,7 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
431470
data : dict
432471
The dict to deserialize from.
433472
"""
473+
data = copy.deepcopy(data)
434474
layers = data.pop("layers")
435475
obj = cls(**data)
436476
super(EmbeddingNet, obj).__init__(layers)
@@ -481,9 +521,10 @@ def __init__(
481521
i_in, i_ot = neuron[-1], out_dim
482522
self.layers.append(
483523
NativeLayer(
484-
rng.normal(size=(i_in, i_ot)),
485-
b=rng.normal(size=(i_ot)) if bias_out else None,
486-
idt=None,
524+
i_in,
525+
i_ot,
526+
bias=bias_out,
527+
use_timestep=False,
487528
activation_function=None,
488529
resnet=False,
489530
precision=precision,
@@ -520,6 +561,7 @@ def deserialize(cls, data: dict) -> "FittingNet":
520561
data : dict
521562
The dict to deserialize from.
522563
"""
564+
data = copy.deepcopy(data)
523565
layers = data.pop("layers")
524566
obj = cls(**data)
525567
NativeNet.__init__(obj, layers)

deepmd_utils/model_format/se_e2_a.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
except ImportError:
77
__version__ = "unknown"
88

9+
import copy
910
from typing import (
1011
Any,
1112
List,
@@ -270,6 +271,7 @@ def serialize(self) -> dict:
270271

271272
@classmethod
272273
def deserialize(cls, data: dict) -> "DescrptSeA":
274+
data = copy.deepcopy(data)
273275
variables = data.pop("@variables")
274276
embeddings = data.pop("embeddings")
275277
env_mat = data.pop("env_mat")

source/tests/test_model_format_utils.py

Lines changed: 84 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,39 +35,74 @@ def test_serialize_deserize(self):
3535
[None, [4], [3, 2]],
3636
["float32", "float64", "single", "double"],
3737
):
38-
ww = np.full((ni, no), 3.0)
39-
bb = np.full((no,), 4.0) if bias else None
40-
idt = np.full((no,), 5.0) if ut else None
41-
nl0 = NativeLayer(ww, bb, idt, activation_function, resnet, prec)
38+
nl0 = NativeLayer(
39+
ni,
40+
no,
41+
bias=bias,
42+
use_timestep=ut,
43+
activation_function=activation_function,
44+
resnet=resnet,
45+
precision=prec,
46+
)
4247
nl1 = NativeLayer.deserialize(nl0.serialize())
43-
inp_shap = [ww.shape[0]]
48+
inp_shap = [ni]
4449
if ashp is not None:
4550
inp_shap = ashp + inp_shap
4651
inp = np.arange(np.prod(inp_shap)).reshape(inp_shap)
4752
np.testing.assert_allclose(nl0.call(inp), nl1.call(inp))
4853

54+
def test_shape_error(self):
    """Deserializing a layer with mismatched ``b`` or ``idt`` must fail.

    ``w0`` has shape (2, 3), so a length-2 ``b`` or ``idt`` disagrees with
    dim 1 of ``w`` and ``NativeLayer.deserialize`` is expected to raise
    ``ValueError`` naming the offending variable.
    """
    self.w0 = np.full((2, 3), 3.0)
    self.b0 = np.full((2,), 4.0)
    self.b1 = np.full((3,), 4.0)
    self.idt0 = np.full((2,), 4.0)
    with self.assertRaises(ValueError) as context:
        NativeLayer.deserialize(
            {
                "activation_function": "tanh",
                "resnet": True,
                "@variables": {"w": self.w0, "b": self.b0},
            }
        )
    # The message check must live outside the ``with`` body: statements after
    # the raising call inside it never execute. Compare against the string
    # form, since an exception object does not support ``in``.
    self.assertIn("not equal to shape of b", str(context.exception))
    with self.assertRaises(ValueError) as context:
        NativeLayer.deserialize(
            {
                "activation_function": "tanh",
                "resnet": True,
                "@variables": {"w": self.w0, "b": self.b1, "idt": self.idt0},
            }
        )
    self.assertIn("not equal to shape of idt", str(context.exception))
77+
4978

5079
class TestNativeNet(unittest.TestCase):
5180
def setUp(self) -> None:
52-
self.w = np.full((2, 3), 3.0)
53-
self.b = np.full((3,), 4.0)
54-
self.idt = np.full((3,), 5.0)
81+
self.w0 = np.full((2, 3), 3.0)
82+
self.b0 = np.full((3,), 4.0)
83+
self.w1 = np.full((3, 4), 3.0)
84+
self.b1 = np.full((4,), 4.0)
5585

5686
def test_serialize(self):
57-
network = NativeNet()
58-
network[1]["w"] = self.w
59-
network[1]["b"] = self.b
60-
network[0]["w"] = self.w
61-
network[0]["b"] = self.b
87+
network = NativeNet(
88+
[
89+
NativeLayer(2, 3).serialize(),
90+
NativeLayer(3, 4).serialize(),
91+
]
92+
)
93+
network[1]["w"] = self.w1
94+
network[1]["b"] = self.b1
95+
network[0]["w"] = self.w0
96+
network[0]["b"] = self.b0
6297
network[1]["activation_function"] = "tanh"
6398
network[0]["activation_function"] = "tanh"
6499
network[1]["resnet"] = True
65100
network[0]["resnet"] = True
66101
jdata = network.serialize()
67-
np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["w"], self.w)
68-
np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["b"], self.b)
69-
np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["w"], self.w)
70-
np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["b"], self.b)
102+
np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["w"], self.w0)
103+
np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["b"], self.b0)
104+
np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["w"], self.w1)
105+
np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["b"], self.b1)
71106
np.testing.assert_array_equal(jdata["layers"][0]["activation_function"], "tanh")
72107
np.testing.assert_array_equal(jdata["layers"][1]["activation_function"], "tanh")
73108
np.testing.assert_array_equal(jdata["layers"][0]["resnet"], True)
@@ -80,25 +115,45 @@ def test_deserialize(self):
80115
{
81116
"activation_function": "tanh",
82117
"resnet": True,
83-
"@variables": {"w": self.w, "b": self.b},
118+
"@variables": {"w": self.w0, "b": self.b0},
84119
},
85120
{
86121
"activation_function": "tanh",
87122
"resnet": True,
88-
"@variables": {"w": self.w, "b": self.b},
123+
"@variables": {"w": self.w1, "b": self.b1},
89124
},
90125
],
91126
}
92127
)
93-
np.testing.assert_array_equal(network[0]["w"], self.w)
94-
np.testing.assert_array_equal(network[0]["b"], self.b)
95-
np.testing.assert_array_equal(network[1]["w"], self.w)
96-
np.testing.assert_array_equal(network[1]["b"], self.b)
128+
np.testing.assert_array_equal(network[0]["w"], self.w0)
129+
np.testing.assert_array_equal(network[0]["b"], self.b0)
130+
np.testing.assert_array_equal(network[1]["w"], self.w1)
131+
np.testing.assert_array_equal(network[1]["b"], self.b1)
97132
np.testing.assert_array_equal(network[0]["activation_function"], "tanh")
98133
np.testing.assert_array_equal(network[1]["activation_function"], "tanh")
99134
np.testing.assert_array_equal(network[0]["resnet"], True)
100135
np.testing.assert_array_equal(network[1]["resnet"], True)
101136

137+
def test_shape_error(self):
    """Deserializing a net whose layers do not chain must fail.

    Both layers use ``w0`` (set in ``setUp``, shape (2, 3)), so the first
    layer's output dim (3) differs from the second layer's input dim (2)
    and ``NativeNet.deserialize`` is expected to raise ``ValueError``.
    """
    with self.assertRaises(ValueError) as context:
        NativeNet.deserialize(
            {
                "layers": [
                    {
                        "activation_function": "tanh",
                        "resnet": True,
                        "@variables": {"w": self.w0, "b": self.b0},
                    },
                    {
                        "activation_function": "tanh",
                        "resnet": True,
                        "@variables": {"w": self.w0, "b": self.b0},
                    },
                ],
            }
        )
    # Checked after the ``with`` body (code following the raising call inside
    # it is never reached) and against ``str(...)`` because an exception
    # object is not a container.
    self.assertIn("does not match the dim of layer", str(context.exception))
156+
102157

103158
class TestEmbeddingNet(unittest.TestCase):
104159
def test_embedding_net(self):
@@ -146,19 +201,21 @@ def test_fitting_net(self):
146201

147202
class TestNetworkCollection(unittest.TestCase):
148203
def setUp(self) -> None:
149-
w = np.full((2, 3), 3.0)
150-
b = np.full((3,), 4.0)
204+
w0 = np.full((2, 3), 3.0)
205+
b0 = np.full((3,), 4.0)
206+
w1 = np.full((3, 4), 3.0)
207+
b1 = np.full((4,), 4.0)
151208
self.network = {
152209
"layers": [
153210
{
154211
"activation_function": "tanh",
155212
"resnet": True,
156-
"@variables": {"w": w, "b": b},
213+
"@variables": {"w": w0, "b": b0},
157214
},
158215
{
159216
"activation_function": "tanh",
160217
"resnet": True,
161-
"@variables": {"w": w, "b": b},
218+
"@variables": {"w": w1, "b": b1},
162219
},
163220
],
164221
}

0 commit comments

Comments
 (0)