Skip to content

Commit e1c868e

Browse files
authored
feat(jax/array-api): se_atten_v2 (#4289)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new descriptor class `DescrptSeAttenV2`, enhancing the framework's capabilities. - Improved test suite flexibility to support multiple backends, including JAX and strict array API. - **Bug Fixes** - Updated serialization logic for `davg` and `dstd` to ensure consistent output format. - **Tests** - Enhanced `TestSeAttenV2` class with new properties and methods for better backend evaluation. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 6a75c6b commit e1c868e

File tree

4 files changed

+123
-2
lines changed

4 files changed

+123
-2
lines changed

deepmd/dpmodel/descriptor/se_atten_v2.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
DEFAULT_PRECISION,
1212
PRECISION_DICT,
1313
)
14+
from deepmd.dpmodel.common import (
15+
to_numpy_array,
16+
)
1417
from deepmd.dpmodel.utils import (
1518
NetworkCollection,
1619
)
@@ -146,8 +149,8 @@ def serialize(self) -> dict:
146149
"exclude_types": obj.exclude_types,
147150
"env_protection": obj.env_protection,
148151
"@variables": {
149-
"davg": obj["davg"],
150-
"dstd": obj["dstd"],
152+
"davg": to_numpy_array(obj["davg"]),
153+
"dstd": to_numpy_array(obj["dstd"]),
151154
},
152155
## to be updated when the options are supported.
153156
"trainable": self.trainable,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP
3+
from deepmd.jax.descriptor.base_descriptor import (
4+
BaseDescriptor,
5+
)
6+
from deepmd.jax.descriptor.dpa1 import (
7+
DescrptDPA1,
8+
)
9+
10+
11+
@BaseDescriptor.register("se_atten_v2")
12+
class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP):
13+
pass
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP
3+
4+
from .dpa1 import (
5+
DescrptDPA1,
6+
)
7+
8+
9+
class DescrptSeAttenV2(DescrptDPA1, DescrptSeAttenV2DP):
10+
pass

source/tests/consistent/descriptor/test_se_atten_v2.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
)
1717

1818
from ..common import (
19+
INSTALLED_ARRAY_API_STRICT,
20+
INSTALLED_JAX,
1921
INSTALLED_PT,
2022
CommonTest,
2123
parameterized,
@@ -30,6 +32,18 @@
3032
)
3133
else:
3234
DescrptSeAttenV2PT = None
35+
if INSTALLED_JAX:
36+
from deepmd.jax.descriptor.se_atten_v2 import (
37+
DescrptSeAttenV2 as DescrptSeAttenV2JAX,
38+
)
39+
else:
40+
DescrptSeAttenV2JAX = None
41+
if INSTALLED_ARRAY_API_STRICT:
42+
from ...array_api_strict.descriptor.se_atten_v2 import (
43+
DescrptSeAttenV2 as DescrptSeAttenV2Strict,
44+
)
45+
else:
46+
DescrptSeAttenV2Strict = None
3347
DescrptSeAttenV2TF = None
3448
from deepmd.utils.argcheck import (
3549
descrpt_se_atten_args,
@@ -175,9 +189,70 @@ def skip_dp(self) -> bool:
175189
def skip_tf(self) -> bool:
176190
return True
177191

192+
@property
193+
def skip_jax(self) -> bool:
194+
(
195+
tebd_dim,
196+
resnet_dt,
197+
type_one_side,
198+
attn,
199+
attn_layer,
200+
attn_dotr,
201+
excluded_types,
202+
env_protection,
203+
set_davg_zero,
204+
scaling_factor,
205+
normalize,
206+
temperature,
207+
ln_eps,
208+
concat_output_tebd,
209+
precision,
210+
use_econf_tebd,
211+
use_tebd_bias,
212+
) = self.param
213+
return not INSTALLED_JAX or self.is_meaningless_zero_attention_layer_tests(
214+
attn_layer,
215+
attn_dotr,
216+
normalize,
217+
temperature,
218+
)
219+
220+
@property
221+
def skip_array_api_strict(self) -> bool:
222+
(
223+
tebd_dim,
224+
resnet_dt,
225+
type_one_side,
226+
attn,
227+
attn_layer,
228+
attn_dotr,
229+
excluded_types,
230+
env_protection,
231+
set_davg_zero,
232+
scaling_factor,
233+
normalize,
234+
temperature,
235+
ln_eps,
236+
concat_output_tebd,
237+
precision,
238+
use_econf_tebd,
239+
use_tebd_bias,
240+
) = self.param
241+
return (
242+
not INSTALLED_ARRAY_API_STRICT
243+
or self.is_meaningless_zero_attention_layer_tests(
244+
attn_layer,
245+
attn_dotr,
246+
normalize,
247+
temperature,
248+
)
249+
)
250+
178251
tf_class = DescrptSeAttenV2TF
179252
dp_class = DescrptSeAttenV2DP
180253
pt_class = DescrptSeAttenV2PT
254+
jax_class = DescrptSeAttenV2JAX
255+
array_api_strict_class = DescrptSeAttenV2Strict
181256
args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False))
182257

183258
def setUp(self):
@@ -244,6 +319,26 @@ def eval_pt(self, pt_obj: Any) -> Any:
244319
mixed_types=True,
245320
)
246321

322+
def eval_jax(self, jax_obj: Any) -> Any:
323+
return self.eval_jax_descriptor(
324+
jax_obj,
325+
self.natoms,
326+
self.coords,
327+
self.atype,
328+
self.box,
329+
mixed_types=True,
330+
)
331+
332+
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
333+
return self.eval_array_api_strict_descriptor(
334+
array_api_strict_obj,
335+
self.natoms,
336+
self.coords,
337+
self.atype,
338+
self.box,
339+
mixed_types=True,
340+
)
341+
247342
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
248343
return (ret[0],)
249344

0 commit comments

Comments
 (0)