Skip to content

Commit 6a75c6b

Browse files
authored
feat(jax/array-api): se_t_tebd (#4288)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced support for JAX as a backend for the "se_e3_tebd" descriptor, enhancing flexibility in computational options.
  - Added serialization and deserialization methods to the descriptor classes for better state management.
- **Bug Fixes**
  - Improved handling of attributes in the descriptor classes to ensure correct data types and transformations.
- **Tests**
  - Enhanced the test suite to support multiple backends, including JAX and Array API Strict, improving the robustness of testing.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 8355947 commit 6a75c6b

File tree

5 files changed

+253
-41
lines changed

5 files changed

+253
-41
lines changed

deepmd/dpmodel/descriptor/se_t_tebd.py

Lines changed: 111 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,20 @@
55
Union,
66
)
77

8+
import array_api_compat
89
import numpy as np
910

1011
from deepmd.dpmodel import (
1112
PRECISION_DICT,
1213
NativeOP,
1314
)
15+
from deepmd.dpmodel.array_api import (
16+
xp_take_along_axis,
17+
)
18+
from deepmd.dpmodel.common import (
19+
get_xp_precision,
20+
to_numpy_array,
21+
)
1422
from deepmd.dpmodel.utils import (
1523
EmbeddingNet,
1624
EnvMat,
@@ -26,9 +34,6 @@
2634
from deepmd.dpmodel.utils.update_sel import (
2735
UpdateSel,
2836
)
29-
from deepmd.env import (
30-
GLOBAL_NP_FLOAT_PRECISION,
31-
)
3237
from deepmd.utils.data_system import (
3338
DeepmdDataSystem,
3439
)
@@ -318,11 +323,15 @@ def call(
318323
sw
319324
The smooth switch function.
320325
"""
326+
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
321327
del mapping
322328
nf, nloc, nnei = nlist.shape
323-
nall = coord_ext.reshape(nf, -1).shape[1] // 3
329+
nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3
324330
# nf x nall x tebd_dim
325-
atype_embd_ext = self.type_embedding.call()[atype_ext]
331+
atype_embd_ext = xp.reshape(
332+
xp.take(self.type_embedding.call(), xp.reshape(atype_ext, [-1]), axis=0),
333+
(nf, nall, self.tebd_dim),
334+
)
326335
# nfnl x tebd_dim
327336
atype_embd = atype_embd_ext[:, :nloc, :]
328337
grrg, g2, h2, rot_mat, sw = self.se_ttebd(
@@ -334,8 +343,8 @@ def call(
334343
)
335344
# nf x nloc x (ng + tebd_dim)
336345
if self.concat_output_tebd:
337-
grrg = np.concatenate(
338-
[grrg, atype_embd.reshape(nf, nloc, self.tebd_dim)], axis=-1
346+
grrg = xp.concat(
347+
[grrg, xp.reshape(atype_embd, (nf, nloc, self.tebd_dim))], axis=-1
339348
)
340349
return grrg, rot_mat, None, None, sw
341350

@@ -368,8 +377,8 @@ def serialize(self) -> dict:
368377
"env_protection": obj.env_protection,
369378
"smooth": self.smooth,
370379
"@variables": {
371-
"davg": obj["davg"],
372-
"dstd": obj["dstd"],
380+
"davg": to_numpy_array(obj["davg"]),
381+
"dstd": to_numpy_array(obj["dstd"]),
373382
},
374383
"trainable": self.trainable,
375384
}
@@ -491,33 +500,35 @@ def __init__(
491500
else:
492501
self.embd_input_dim = 1
493502

494-
self.embeddings = NetworkCollection(
503+
embeddings = NetworkCollection(
495504
ndim=0,
496505
ntypes=self.ntypes,
497506
network_type="embedding_network",
498507
)
499-
self.embeddings[0] = EmbeddingNet(
508+
embeddings[0] = EmbeddingNet(
500509
self.embd_input_dim,
501510
self.neuron,
502511
self.activation_function,
503512
self.resnet_dt,
504513
self.precision,
505514
seed=child_seed(seed, 0),
506515
)
516+
self.embeddings = embeddings
507517
if self.tebd_input_mode in ["strip"]:
508-
self.embeddings_strip = NetworkCollection(
518+
embeddings_strip = NetworkCollection(
509519
ndim=0,
510520
ntypes=self.ntypes,
511521
network_type="embedding_network",
512522
)
513-
self.embeddings_strip[0] = EmbeddingNet(
523+
embeddings_strip[0] = EmbeddingNet(
514524
self.tebd_dim_input,
515525
self.neuron,
516526
self.activation_function,
517527
self.resnet_dt,
518528
self.precision,
519529
seed=child_seed(seed, 1),
520530
)
531+
self.embeddings_strip = embeddings_strip
521532
else:
522533
self.embeddings_strip = None
523534

@@ -652,82 +663,85 @@ def call(
652663
atype_embd_ext: Optional[np.ndarray] = None,
653664
mapping: Optional[np.ndarray] = None,
654665
):
666+
xp = array_api_compat.array_namespace(nlist, coord_ext, atype_ext)
655667
# nf x nloc x nnei x 4
656668
dmatrix, diff, sw = self.env_mat.call(
657669
coord_ext, atype_ext, nlist, self.mean, self.stddev
658670
)
659671
nf, nloc, nnei, _ = dmatrix.shape
660672
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
661673
# nfnl x nnei
662-
exclude_mask = exclude_mask.reshape(nf * nloc, nnei)
674+
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
663675
# nfnl x nnei
664-
nlist = nlist.reshape(nf * nloc, nnei)
665-
nlist = np.where(exclude_mask, nlist, -1)
676+
nlist = xp.reshape(nlist, (nf * nloc, nnei))
677+
nlist = xp.where(exclude_mask, nlist, xp.full_like(nlist, -1))
666678
# nfnl x nnei
667679
nlist_mask = nlist != -1
668680
# nfnl x nnei x 1
669-
sw = np.where(nlist_mask[:, :, None], sw.reshape(nf * nloc, nnei, 1), 0.0)
681+
sw = xp.where(
682+
nlist_mask[:, :, None],
683+
xp.reshape(sw, (nf * nloc, nnei, 1)),
684+
xp.zeros((nf * nloc, nnei, 1), dtype=sw.dtype),
685+
)
670686

671687
# nfnl x nnei x 4
672-
dmatrix = dmatrix.reshape(nf * nloc, nnei, 4)
688+
dmatrix = xp.reshape(dmatrix, (nf * nloc, nnei, 4))
673689
# nfnl x nnei x 4
674690
rr = dmatrix
675-
rr = rr * exclude_mask[:, :, None]
691+
rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype)
676692
# nfnl x nt_i x 3
677693
rr_i = rr[:, :, 1:]
678694
# nfnl x nt_j x 3
679695
rr_j = rr[:, :, 1:]
680696
# nfnl x nt_i x nt_j
681-
env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
697+
# env_ij = np.einsum("ijm,ikm->ijk", rr_i, rr_j)
698+
env_ij = xp.sum(rr_i[:, :, None, :] * rr_j[:, None, :, :], axis=-1)
682699
# nfnl x nt_i x nt_j x 1
683-
ss = np.expand_dims(env_ij, axis=-1)
700+
ss = env_ij[..., None]
684701

685-
nlist_masked = np.where(nlist_mask, nlist, 0)
686-
index = np.tile(nlist_masked.reshape(nf, -1, 1), (1, 1, self.tebd_dim))
702+
nlist_masked = xp.where(nlist_mask, nlist, xp.zeros_like(nlist))
703+
index = xp.tile(xp.reshape(nlist_masked, (nf, -1, 1)), (1, 1, self.tebd_dim))
687704
# nfnl x nnei x tebd_dim
688-
atype_embd_nlist = np.take_along_axis(atype_embd_ext, index, axis=1).reshape(
689-
nf * nloc, nnei, self.tebd_dim
705+
atype_embd_nlist = xp_take_along_axis(atype_embd_ext, index, axis=1)
706+
atype_embd_nlist = xp.reshape(
707+
atype_embd_nlist, (nf * nloc, nnei, self.tebd_dim)
690708
)
691709
# nfnl x nt_i x nt_j x tebd_dim
692-
nlist_tebd_i = np.tile(
693-
np.expand_dims(atype_embd_nlist, axis=2), [1, 1, self.nnei, 1]
694-
)
695-
nlist_tebd_j = np.tile(
696-
np.expand_dims(atype_embd_nlist, axis=1), [1, self.nnei, 1, 1]
697-
)
710+
nlist_tebd_i = xp.tile(atype_embd_nlist[:, :, None, :], (1, 1, self.nnei, 1))
711+
nlist_tebd_j = xp.tile(atype_embd_nlist[:, None, :, :], (1, self.nnei, 1, 1))
698712
ng = self.neuron[-1]
699713

700714
if self.tebd_input_mode in ["concat"]:
701715
# nfnl x nt_i x nt_j x (1 + tebd_dim * 2)
702-
ss = np.concatenate([ss, nlist_tebd_i, nlist_tebd_j], axis=-1)
716+
ss = xp.concat([ss, nlist_tebd_i, nlist_tebd_j], axis=-1)
703717
# nfnl x nt_i x nt_j x ng
704718
gg = self.cal_g(ss, 0)
705719
elif self.tebd_input_mode in ["strip"]:
706720
# nfnl x nt_i x nt_j x ng
707721
gg_s = self.cal_g(ss, 0)
708722
assert self.embeddings_strip is not None
709723
# nfnl x nt_i x nt_j x (tebd_dim * 2)
710-
tt = np.concatenate([nlist_tebd_i, nlist_tebd_j], axis=-1)
724+
tt = xp.concat([nlist_tebd_i, nlist_tebd_j], axis=-1)
711725
# nfnl x nt_i x nt_j x ng
712726
gg_t = self.cal_g_strip(tt, 0)
713727
if self.smooth:
714728
gg_t = (
715729
gg_t
716-
* sw.reshape(nf * nloc, self.nnei, 1, 1)
717-
* sw.reshape(nf * nloc, 1, self.nnei, 1)
730+
* xp.reshape(sw, (nf * nloc, self.nnei, 1, 1))
731+
* xp.reshape(sw, (nf * nloc, 1, self.nnei, 1))
718732
)
719733
# nfnl x nt_i x nt_j x ng
720734
gg = gg_s * gg_t + gg_s
721735
else:
722736
raise NotImplementedError
723737

724738
# nfnl x ng
725-
res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
739+
# res_ij = np.einsum("ijk,ijkm->im", env_ij, gg)
740+
res_ij = xp.sum(env_ij[:, :, :, None] * gg[:, :, :, :], axis=(1, 2))
726741
res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
727742
# nf x nl x ng
728-
result = res_ij.reshape(nf, nloc, self.filter_neuron[-1]).astype(
729-
GLOBAL_NP_FLOAT_PRECISION
730-
)
743+
result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1]))
744+
result = xp.astype(result, get_xp_precision(xp, "global"))
731745
return (
732746
result,
733747
None,
@@ -743,3 +757,61 @@ def has_message_passing(self) -> bool:
743757
def need_sorted_nlist_for_lower(self) -> bool:
    """Tell whether this descriptor block needs a sorted nlist in `forward_lower`.

    Returns
    -------
    bool
        Always ``False``.
    """
    needs_sorted_nlist = False
    return needs_sorted_nlist
760+
761+
def serialize(self) -> dict:
    """Serialize the descriptor to dict.

    Returns
    -------
    dict
        The hyper-parameters of the block, the serialized embedding
        network(s) and env_mat, and the ``davg``/``dstd`` statistics
        converted with ``to_numpy_array``.
    """
    serialized = {
        "@class": "Descriptor",
        "type": "se_e3_tebd",
        "@version": 1,
        "rcut": self.rcut,
        "rcut_smth": self.rcut_smth,
        "sel": self.sel,
        "ntypes": self.ntypes,
        "neuron": self.neuron,
        "tebd_dim": self.tebd_dim,
        "tebd_input_mode": self.tebd_input_mode,
        "set_davg_zero": self.set_davg_zero,
        "activation_function": self.activation_function,
        "resnet_dt": self.resnet_dt,
        # make deterministic
        "precision": np.dtype(PRECISION_DICT[self.precision]).name,
        "embeddings": self.embeddings.serialize(),
        "env_mat": self.env_mat.serialize(),
        "exclude_types": self.exclude_types,
        "env_protection": self.env_protection,
        "smooth": self.smooth,
    }
    # Statistics go last so the key order matches the literal layout above.
    serialized["@variables"] = {
        "davg": to_numpy_array(self["davg"]),
        "dstd": to_numpy_array(self["dstd"]),
    }
    # The strip network only exists in "strip" mode.
    if self.tebd_input_mode in ["strip"]:
        serialized["embeddings_strip"] = self.embeddings_strip.serialize()
    return serialized
793+
794+
@classmethod
def deserialize(cls, data: dict) -> "DescrptBlockSeTTebd":
    """Deserialize from dict.

    Parameters
    ----------
    data : dict
        The serialized descriptor block. The caller's dict is left
        untouched; a shallow copy is consumed instead.

    Returns
    -------
    DescrptBlockSeTTebd
        An instance of ``cls`` built from the remaining keys of `data`,
        with ``davg``/``dstd`` and the embedding network(s) restored.
    """
    data = data.copy()
    check_version_compatibility(data.pop("@version"), 1, 1)
    data.pop("@class")
    data.pop("type")
    variables = data.pop("@variables")
    embeddings = data.pop("embeddings")
    # The serialized env_mat is not restored here; it is only removed so
    # that it is not forwarded to the constructor as a keyword argument.
    data.pop("env_mat")
    # "tebd_input_mode" stays in `data` because the constructor needs it.
    strip_mode = data["tebd_input_mode"] in ["strip"]
    embeddings_strip = data.pop("embeddings_strip") if strip_mode else None
    se_ttebd = cls(**data)

    se_ttebd["davg"] = variables["davg"]
    se_ttebd["dstd"] = variables["dstd"]
    se_ttebd.embeddings = NetworkCollection.deserialize(embeddings)
    if strip_mode:
        se_ttebd.embeddings_strip = NetworkCollection.deserialize(embeddings_strip)

    return se_ttebd

deepmd/jax/descriptor/se_t_tebd.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
from deepmd.dpmodel.descriptor.se_t_tebd import (
7+
DescrptBlockSeTTebd as DescrptBlockSeTTebdDP,
8+
)
9+
from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP
10+
from deepmd.jax.common import (
11+
ArrayAPIVariable,
12+
flax_module,
13+
to_jax_array,
14+
)
15+
from deepmd.jax.descriptor.base_descriptor import (
16+
BaseDescriptor,
17+
)
18+
from deepmd.jax.utils.exclude_mask import (
19+
PairExcludeMask,
20+
)
21+
from deepmd.jax.utils.network import (
22+
NetworkCollection,
23+
)
24+
from deepmd.jax.utils.type_embed import (
25+
TypeEmbedNet,
26+
)
27+
28+
29+
@flax_module
class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
    """JAX variant of the dpmodel ``DescrptBlockSeTTebd``.

    Attribute assignment is intercepted so that selected attributes are
    converted to the JAX-side types imported by this module before they
    are stored.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Convert `value` according to `name`, then store it.

        Parameters
        ----------
        name : str
            The attribute name.
        value : Any
            The value being assigned.
        """
        if name in ("mean", "stddev"):
            converted = to_jax_array(value)
            # A ``None`` statistic is stored as-is; anything else is wrapped.
            value = converted if converted is None else ArrayAPIVariable(converted)
        elif name in ("embeddings", "embeddings_strip") and value is not None:
            # Round-trip through serialization to rebuild the collection
            # with this module's NetworkCollection.
            value = NetworkCollection.deserialize(value.serialize())
        elif name == "emask":
            value = PairExcludeMask(value.ntypes, value.exclude_types)
        # "env_mat" doesn't store any value, so it (like every other
        # attribute) is stored unchanged.
        return super().__setattr__(name, value)
46+
47+
48+
@BaseDescriptor.register("se_e3_tebd")
@flax_module
class DescrptSeTTebd(DescrptSeTTebdDP):
    """JAX variant of the dpmodel ``DescrptSeTTebd``, registered as ``se_e3_tebd``."""

    def __setattr__(self, name: str, value: Any) -> None:
        """Rebuild sub-modules with this module's classes before storing them.

        Parameters
        ----------
        name : str
            The attribute name.
        value : Any
            The value being assigned.
        """
        # Attributes that are re-created through a serialize/deserialize
        # round trip, keyed by attribute name.
        rebuilders = {
            "se_ttebd": DescrptBlockSeTTebd,
            "type_embedding": TypeEmbedNet,
        }
        target_cls = rebuilders.get(name)
        if target_cls is not None:
            value = target_cls.deserialize(value.serialize())
        return super().__setattr__(name, value)

doc/model/train-se-e3-tebd.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# Descriptor `"se_e3_tebd"` {{ pytorch_icon }} {{ dpmodel_icon }}
1+
# Descriptor `"se_e3_tebd"` {{ pytorch_icon }} {{ jax_icon }} {{ dpmodel_icon }}
22

33
:::{note}
4-
**Supported backends**: PyTorch {{ pytorch_icon }}, DP {{ dpmodel_icon }}
4+
**Supported backends**: PyTorch {{ pytorch_icon }}, JAX {{ jax_icon }}, DP {{ dpmodel_icon }}
55
:::
66

77
The notation of `se_e3_tebd` is short for the three-body embedding descriptor with type embeddings, where the notation `se` denotes the Deep Potential Smooth Edition (DeepPot-SE).
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Any,
4+
)
5+
6+
from deepmd.dpmodel.descriptor.se_t_tebd import (
7+
DescrptBlockSeTTebd as DescrptBlockSeTTebdDP,
8+
)
9+
from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP
10+
11+
from ..common import (
12+
to_array_api_strict_array,
13+
)
14+
from ..utils.exclude_mask import (
15+
PairExcludeMask,
16+
)
17+
from ..utils.network import (
18+
NetworkCollection,
19+
)
20+
from ..utils.type_embed import (
21+
TypeEmbedNet,
22+
)
23+
24+
25+
class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP):
    """Array-API-strict variant of the dpmodel ``DescrptBlockSeTTebd``.

    Attribute assignment is intercepted so that selected attributes are
    converted with the helpers imported by this module before they are
    stored.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """Convert `value` according to `name`, then store it.

        Parameters
        ----------
        name : str
            The attribute name.
        value : Any
            The value being assigned.
        """
        if name in ("mean", "stddev"):
            value = to_array_api_strict_array(value)
        elif name in ("embeddings", "embeddings_strip") and value is not None:
            # Round-trip through serialization to rebuild the collection
            # with this package's NetworkCollection.
            value = NetworkCollection.deserialize(value.serialize())
        elif name == "emask":
            value = PairExcludeMask(value.ntypes, value.exclude_types)
        # "env_mat" doesn't store any value, so it (like every other
        # attribute) is stored unchanged.
        return super().__setattr__(name, value)
39+
40+
41+
class DescrptSeTTebd(DescrptSeTTebdDP):
    """Array-API-strict variant of the dpmodel ``DescrptSeTTebd``."""

    def __setattr__(self, name: str, value: Any) -> None:
        """Rebuild sub-modules with this package's classes before storing them.

        Parameters
        ----------
        name : str
            The attribute name.
        value : Any
            The value being assigned.
        """
        # Attributes that are re-created through a serialize/deserialize
        # round trip, keyed by attribute name.
        rebuilders = {
            "se_ttebd": DescrptBlockSeTTebd,
            "type_embedding": TypeEmbedNet,
        }
        target_cls = rebuilders.get(name)
        if target_cls is not None:
            value = target_cls.deserialize(value.serialize())
        return super().__setattr__(name, value)

0 commit comments

Comments
 (0)