Skip to content

Commit eb2832b

Browse files
authored
feat(jax/array-api): se_e3 (#4286)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new descriptor class `DescrptSeT` for enhanced compatibility with array APIs. - Added support for JAX as a backend option for the `"se_e3"` descriptor. - **Bug Fixes** - Improved array handling in the `clear` method of the `NN` class to ensure compatibility across different array implementations. - **Documentation** - Updated the module exports to include the new `DescrptSeT` class. - Expanded documentation to reflect JAX as a supported backend for the `"se_e3"` descriptor. - **Tests** - Enhanced the test suite to support additional computational backends and added new evaluation methods. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 5c32147 commit eb2832b

File tree

7 files changed

+144
-23
lines changed

7 files changed

+144
-23
lines changed

deepmd/dpmodel/descriptor/se_t.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@
66
Union,
77
)
88

9+
import array_api_compat
910
import numpy as np
1011

1112
from deepmd.dpmodel import (
1213
DEFAULT_PRECISION,
1314
PRECISION_DICT,
1415
NativeOP,
1516
)
17+
from deepmd.dpmodel.common import (
18+
get_xp_precision,
19+
to_numpy_array,
20+
)
1621
from deepmd.dpmodel.utils import (
1722
EmbeddingNet,
1823
EnvMat,
@@ -25,9 +30,6 @@
2530
from deepmd.dpmodel.utils.update_sel import (
2631
UpdateSel,
2732
)
28-
from deepmd.env import (
29-
GLOBAL_NP_FLOAT_PRECISION,
30-
)
3133
from deepmd.utils.data_system import (
3234
DeepmdDataSystem,
3335
)
@@ -122,26 +124,28 @@ def __init__(
122124
# order matters, placed after the assignment of self.ntypes
123125
self.reinit_exclude(exclude_types)
124126
self.trainable = trainable
127+
self.sel_cumsum = [0, *np.cumsum(self.sel).tolist()]
125128

126129
in_dim = 1 # not considiering type embedding
127-
self.embeddings = NetworkCollection(
130+
embeddings = NetworkCollection(
128131
ntypes=self.ntypes,
129132
ndim=2,
130133
network_type="embedding_network",
131134
)
132135
for ii, embedding_idx in enumerate(
133-
itertools.product(range(self.ntypes), repeat=self.embeddings.ndim)
136+
itertools.product(range(self.ntypes), repeat=embeddings.ndim)
134137
):
135-
self.embeddings[embedding_idx] = EmbeddingNet(
138+
embeddings[embedding_idx] = EmbeddingNet(
136139
in_dim,
137140
self.neuron,
138141
self.activation_function,
139142
self.resnet_dt,
140143
self.precision,
141144
seed=child_seed(self.seed, ii),
142145
)
146+
self.embeddings = embeddings
143147
self.env_mat = EnvMat(self.rcut, self.rcut_smth, protection=self.env_protection)
144-
self.nnei = np.sum(self.sel)
148+
self.nnei = sum(self.sel)
145149
self.davg = np.zeros(
146150
[self.ntypes, self.nnei, 4], dtype=PRECISION_DICT[self.precision]
147151
)
@@ -299,20 +303,22 @@ def call(
299303
The smooth switch function.
300304
"""
301305
del mapping
306+
xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist)
302307
# nf x nloc x nnei x 4
303308
rr, diff, ww = self.env_mat.call(
304309
coord_ext, atype_ext, nlist, self.davg, self.dstd
305310
)
306311
nf, nloc, nnei, _ = rr.shape
307-
sec = np.append([0], np.cumsum(self.sel))
312+
sec = self.sel_cumsum
308313

309314
ng = self.neuron[-1]
310-
result = np.zeros([nf * nloc, ng], dtype=PRECISION_DICT[self.precision])
315+
result = xp.zeros([nf * nloc, ng], dtype=get_xp_precision(xp, self.precision))
311316
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
312317
# merge nf and nloc axis, so for type_one_side == False,
313318
# we don't require atype is the same in all frames
314-
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
315-
rr = rr.reshape(nf * nloc, nnei, 4)
319+
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
320+
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
321+
rr = xp.astype(rr, get_xp_precision(xp, self.precision))
316322

317323
for embedding_idx in itertools.product(
318324
range(self.ntypes), repeat=self.embeddings.ndim
@@ -325,23 +331,26 @@ def call(
325331
# nfnl x nt_i x 3
326332
rr_i = rr[:, sec[ti] : sec[ti + 1], 1:]
327333
mm_i = exclude_mask[:, sec[ti] : sec[ti + 1]]
328-
rr_i = rr_i * mm_i[:, :, None]
334+
rr_i = rr_i * xp.astype(mm_i[:, :, None], rr_i.dtype)
329335
# nfnl x nt_j x 3
330336
rr_j = rr[:, sec[tj] : sec[tj + 1], 1:]
331337
mm_j = exclude_mask[:, sec[tj] : sec[tj + 1]]
332-
rr_j = rr_j * mm_j[:, :, None]
338+
rr_j = rr_j * xp.astype(mm_j[:, :, None], rr_j.dtype)
333339
# nfnl x nt_i x nt_j
334-
env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
340+
# env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
341+
env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1)
335342
# nfnl x nt_i x nt_j x 1
336343
env_ij_reshape = env_ij[:, :, :, None]
337344
# nfnl x nt_i x nt_j x ng
338345
gg = self.embeddings[embedding_idx].call(env_ij_reshape)
339346
# nfnl x nt_i x nt_j x ng
340-
res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
347+
# res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
348+
res_ij = xp.sum(env_ij[:, :, :, None] * gg, axis=(1, 2))
341349
res_ij = res_ij * (1.0 / float(nei_type_i) / float(nei_type_j))
342350
result += res_ij
343351
# nf x nloc x ng
344-
result = result.reshape(nf, nloc, ng).astype(GLOBAL_NP_FLOAT_PRECISION)
352+
result = xp.reshape(result, (nf, nloc, ng))
353+
result = xp.astype(result, get_xp_precision(xp, "global"))
345354
return result, None, None, None, ww
346355

347356
def serialize(self) -> dict:
@@ -369,8 +378,8 @@ def serialize(self) -> dict:
369378
"exclude_types": self.exclude_types,
370379
"env_protection": self.env_protection,
371380
"@variables": {
372-
"davg": self.davg,
373-
"dstd": self.dstd,
381+
"davg": to_numpy_array(self.davg),
382+
"dstd": to_numpy_array(self.dstd),
374383
},
375384
"type_map": self.type_map,
376385
"trainable": self.trainable,

deepmd/dpmodel/utils/network.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,11 +572,12 @@ def call(self, x):
572572
def clear(self):
573573
"""Clear the network parameters to zero."""
574574
for layer in self.layers:
575-
layer.w.fill(0.0)
575+
xp = array_api_compat.array_namespace(layer.w)
576+
layer.w = xp.zeros_like(layer.w)
576577
if layer.b is not None:
577-
layer.b.fill(0.0)
578+
layer.b = xp.zeros_like(layer.b)
578579
if layer.idt is not None:
579-
layer.idt.fill(0.0)
580+
layer.idt = xp.zeros_like(layer.idt)
580581

581582
return NN
582583

deepmd/jax/descriptor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@
1111
from deepmd.jax.descriptor.se_e2_r import (
1212
DescrptSeR,
1313
)
14+
from deepmd.jax.descriptor.se_t import (
15+
DescrptSeT,
16+
)
1417

1518
__all__ = [
1619
"DescrptSeA",
1720
"DescrptSeR",
21+
"DescrptSeT",
1822
"DescrptDPA1",
1923
"DescrptHybrid",
2024
]

deepmd/jax/descriptor/se_t.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
7+
from deepmd.jax.common import (
8+
ArrayAPIVariable,
9+
flax_module,
10+
to_jax_array,
11+
)
12+
from deepmd.jax.descriptor.base_descriptor import (
13+
BaseDescriptor,
14+
)
15+
from deepmd.jax.utils.exclude_mask import (
16+
PairExcludeMask,
17+
)
18+
from deepmd.jax.utils.network import (
19+
NetworkCollection,
20+
)
21+
22+
23+
@BaseDescriptor.register("se_e3")
24+
@BaseDescriptor.register("se_at")
25+
@BaseDescriptor.register("se_a_3be")
26+
@flax_module
27+
class DescrptSeT(DescrptSeTDP):
28+
def __setattr__(self, name: str, value: Any) -> None:
29+
if name in {"dstd", "davg"}:
30+
value = to_jax_array(value)
31+
if value is not None:
32+
value = ArrayAPIVariable(value)
33+
elif name in {"embeddings"}:
34+
if value is not None:
35+
value = NetworkCollection.deserialize(value.serialize())
36+
elif name == "env_mat":
37+
# env_mat doesn't store any value
38+
pass
39+
elif name == "emask":
40+
value = PairExcludeMask(value.ntypes, value.exclude_types)
41+
42+
return super().__setattr__(name, value)

doc/model/train-se-e3.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# Descriptor `"se_e3"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ dpmodel_icon }}
1+
# Descriptor `"se_e3"` {{ tensorflow_icon }} {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}
22

33
:::{note}
4-
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
4+
**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
55
:::
66

77
The notation of `se_e3` is short for three-body embedding DeepPot-SE, which incorporates embedded bond-angle information.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP
7+
8+
from ..common import (
9+
to_array_api_strict_array,
10+
)
11+
from ..utils.exclude_mask import (
12+
PairExcludeMask,
13+
)
14+
from ..utils.network import (
15+
NetworkCollection,
16+
)
17+
18+
19+
class DescrptSeT(DescrptSeTDP):
20+
def __setattr__(self, name: str, value: Any) -> None:
21+
if name in {"dstd", "davg"}:
22+
value = to_array_api_strict_array(value)
23+
elif name in {"embeddings"}:
24+
if value is not None:
25+
value = NetworkCollection.deserialize(value.serialize())
26+
elif name == "env_mat":
27+
# env_mat doesn't store any value
28+
pass
29+
elif name == "emask":
30+
value = PairExcludeMask(value.ntypes, value.exclude_types)
31+
32+
return super().__setattr__(name, value)

source/tests/consistent/descriptor/test_se_t.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
)
1313

1414
from ..common import (
15+
INSTALLED_ARRAY_API_STRICT,
16+
INSTALLED_JAX,
1517
INSTALLED_PT,
1618
INSTALLED_TF,
1719
CommonTest,
@@ -29,6 +31,14 @@
2931
from deepmd.tf.descriptor.se_t import DescrptSeT as DescrptSeTTF
3032
else:
3133
DescrptSeTTF = None
34+
if INSTALLED_JAX:
35+
from deepmd.jax.descriptor.se_t import DescrptSeT as DescrptSeTJAX
36+
else:
37+
DescrptSeTJAX = None
38+
if INSTALLED_ARRAY_API_STRICT:
39+
from ...array_api_strict.descriptor.se_t import DescrptSeT as DescrptSeTStrict
40+
else:
41+
DescrptSeTStrict = None
3242
from deepmd.utils.argcheck import (
3343
descrpt_se_t_args,
3444
)
@@ -91,9 +101,14 @@ def skip_tf(self) -> bool:
91101
) = self.param
92102
return env_protection != 0.0 or excluded_types
93103

104+
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
105+
skip_jax = not INSTALLED_JAX
106+
94107
tf_class = DescrptSeTTF
95108
dp_class = DescrptSeTDP
96109
pt_class = DescrptSeTPT
110+
jax_class = DescrptSeTJAX
111+
array_api_strict_class = DescrptSeTStrict
97112
args = descrpt_se_t_args()
98113

99114
def setUp(self):
@@ -168,6 +183,24 @@ def eval_pt(self, pt_obj: Any) -> Any:
168183
self.box,
169184
)
170185

186+
def eval_jax(self, jax_obj: Any) -> Any:
187+
return self.eval_jax_descriptor(
188+
jax_obj,
189+
self.natoms,
190+
self.coords,
191+
self.atype,
192+
self.box,
193+
)
194+
195+
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
196+
return self.eval_array_api_strict_descriptor(
197+
array_api_strict_obj,
198+
self.natoms,
199+
self.coords,
200+
self.atype,
201+
self.box,
202+
)
203+
171204
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
172205
return (ret[0],)
173206

0 commit comments

Comments
 (0)