
Commit 308f97e

wanghan-iapcm (Han Wang) and Han Wang authored
support fitting net (#3137)
- also add doc string for the embedding net

---------

Co-authored-by: Han Wang <[email protected]>
1 parent d5590a4 commit 308f97e

File tree

3 files changed (+143 -3 lines changed)


deepmd_utils/model_format/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@
 )
 from .network import (
     EmbeddingNet,
+    FittingNet,
     NativeLayer,
     NativeNet,
     NetworkCollection,
@@ -32,6 +33,7 @@
     "DescrptSeA",
     "EnvMat",
     "EmbeddingNet",
+    "FittingNet",
     "NativeLayer",
     "NativeNet",
     "NetworkCollection",

deepmd_utils/model_format/network.py

Lines changed: 113 additions & 3 deletions
@@ -162,7 +162,9 @@ def __init__(
         self.w = w.astype(prec) if w is not None else None
         self.b = b.astype(prec) if b is not None else None
         self.idt = idt.astype(prec) if idt is not None else None
-        self.activation_function = activation_function
+        self.activation_function = (
+            activation_function if activation_function is not None else "none"
+        )
         self.resnet = resnet
         self.check_type_consistency()
 
@@ -354,6 +356,24 @@ def call(self, x: np.ndarray) -> np.ndarray:
 
 
 class EmbeddingNet(NativeNet):
+    """The embedding network.
+
+    Parameters
+    ----------
+    in_dim
+        Input dimension.
+    neuron
+        The number of neurons in each layer. The output dimension
+        is the same as the dimension of the last layer.
+    activation_function
+        The activation function.
+    resnet_dt
+        Use time step at the resnet architecture.
+    precision
+        Floating point precision for the model parameters.
+
+    """
+
     def __init__(
         self,
         in_dim,
@@ -370,8 +390,8 @@ def __init__(
             layers.append(
                 NativeLayer(
                     rng.normal(size=(i_in, i_ot)),
-                    b=rng.normal(size=(ii)),
-                    idt=rng.normal(size=(ii)) if resnet_dt else None,
+                    b=rng.normal(size=(i_ot)),
+                    idt=rng.normal(size=(i_ot)) if resnet_dt else None,
                     activation_function=activation_function,
                     resnet=True,
                     precision=precision,
@@ -417,6 +437,95 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
         return obj
 
 
+class FittingNet(EmbeddingNet):
+    """The fitting network. It may be implemented as an embedding
+    net connected with a linear output layer.
+
+    Parameters
+    ----------
+    in_dim
+        Input dimension.
+    out_dim
+        Output dimension.
+    neuron
+        The number of neurons in each hidden layer.
+    activation_function
+        The activation function.
+    resnet_dt
+        Use time step at the resnet architecture.
+    precision
+        Floating point precision for the model parameters.
+    bias_out
+        The last linear layer has bias.
+
+    """
+
+    def __init__(
+        self,
+        in_dim,
+        out_dim,
+        neuron: List[int] = [24, 48, 96],
+        activation_function: str = "tanh",
+        resnet_dt: bool = False,
+        precision: str = DEFAULT_PRECISION,
+        bias_out: bool = True,
+    ):
+        super().__init__(
+            in_dim,
+            neuron=neuron,
+            activation_function=activation_function,
+            resnet_dt=resnet_dt,
+            precision=precision,
+        )
+        rng = np.random.default_rng()
+        i_in, i_ot = neuron[-1], out_dim
+        self.layers.append(
+            NativeLayer(
+                rng.normal(size=(i_in, i_ot)),
+                b=rng.normal(size=(i_ot)) if bias_out else None,
+                idt=None,
+                activation_function=None,
+                resnet=False,
+                precision=precision,
+            )
+        )
+        self.out_dim = out_dim
+        self.bias_out = bias_out
+
+    def serialize(self) -> dict:
+        """Serialize the network to a dict.
+
+        Returns
+        -------
+        dict
+            The serialized network.
+        """
+        return {
+            "in_dim": self.in_dim,
+            "out_dim": self.out_dim,
+            "neuron": self.neuron.copy(),
+            "activation_function": self.activation_function,
+            "resnet_dt": self.resnet_dt,
+            "precision": self.precision,
+            "bias_out": self.bias_out,
+            "layers": [layer.serialize() for layer in self.layers],
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "FittingNet":
+        """Deserialize the network from a dict.
+
+        Parameters
+        ----------
+        data : dict
+            The dict to deserialize from.
+        """
+        layers = data.pop("layers")
+        obj = cls(**data)
+        NativeNet.__init__(obj, layers)
+        return obj
+
+
 class NetworkCollection:
     """A collection of networks for multiple elements.
 
@@ -439,6 +548,7 @@ class NetworkCollection:
     NETWORK_TYPE_MAP: ClassVar[Dict[str, type]] = {
         "network": NativeNet,
        "embedding_network": EmbeddingNet,
+        "fitting_network": FittingNet,
     }
 
     def __init__(
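
Taken together, the new class follows the same build/serialize/deserialize/call pattern as EmbeddingNet. Below is a minimal round-trip sketch based on the constructor and methods added above; the dimensions and options are illustrative only, not part of the commit:

import numpy as np

from deepmd_utils.model_format import FittingNet

# Illustrative sizes; any values accepted by the constructor would do.
net = FittingNet(
    10,                     # in_dim
    1,                      # out_dim
    neuron=[24, 48, 96],
    activation_function="tanh",
    resnet_dt=False,
    precision="double",
    bias_out=True,
)

data = net.serialize()               # plain dict holding options and per-layer parameters
net2 = FittingNet.deserialize(data)  # rebuild an equivalent network from the dict

x = np.ones([10])
# The round trip preserves the forward evaluation (this is what the new test checks).
np.testing.assert_allclose(net.call(x), net2.call(x))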

source/tests/test_model_format_utils.py

Lines changed: 28 additions & 0 deletions
@@ -12,6 +12,7 @@
     DescrptSeA,
     EmbeddingNet,
     EnvMat,
+    FittingNet,
     NativeLayer,
     NativeNet,
     NetworkCollection,
@@ -98,6 +99,8 @@ def test_deserialize(self):
         np.testing.assert_array_equal(network[0]["resnet"], True)
         np.testing.assert_array_equal(network[1]["resnet"], True)
 
+
+class TestEmbeddingNet(unittest.TestCase):
     def test_embedding_net(self):
         for ni, act, idt, prec in itertools.product(
             [1, 10],
@@ -116,6 +119,31 @@ def test_embedding_net(self):
             np.testing.assert_allclose(en0.call(inp), en1.call(inp))
 
 
+class TestFittingNet(unittest.TestCase):
+    def test_fitting_net(self):
+        for ni, no, act, idt, prec, bo in itertools.product(
+            [1, 10],
+            [1, 7],
+            ["tanh", "none"],
+            [True, False],
+            ["double", "single"],
+            [True, False],
+        ):
+            en0 = FittingNet(
+                ni,
+                no,
+                activation_function=act,
+                precision=prec,
+                resnet_dt=idt,
+                bias_out=bo,
+            )
+            en1 = FittingNet.deserialize(en0.serialize())
+            inp = np.ones([ni])
+            en0.call(inp)
+            en1.call(inp)
+            np.testing.assert_allclose(en0.call(inp), en1.call(inp))
+
+
 class TestNetworkCollection(unittest.TestCase):
     def setUp(self) -> None:
         w = np.full((2, 3), 3.0)
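
Beyond the direct round-trip tests, the NETWORK_TYPE_MAP entry added in network.py means a NetworkCollection can now refer to fitting networks by the "fitting_network" key. A small sketch of that mapping, assuming the class attribute remains public as shown in the diff:

from deepmd_utils.model_format import FittingNet, NetworkCollection

# "fitting_network" now resolves to FittingNet, alongside
# "network" -> NativeNet and "embedding_network" -> EmbeddingNet.
assert NetworkCollection.NETWORK_TYPE_MAP["fitting_network"] is FittingNet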
