Skip to content

Commit a468819

Browse files
feat(jax/array-api): property fitting (#4287)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced the `PropertyFittingNet` class for enhanced property-specific fitting operations. - Enhanced testing framework to support additional computational backends (JAX and Array API Strict). - **Bug Fixes** - Improved handling of attribute assignments in property fitting. - **Tests** - Added new methods and properties to the testing suite for evaluating property fitting with JAX and Array API Strict. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 704db2f commit a468819

File tree

3 files changed

+82
-0
lines changed

3 files changed

+82
-0
lines changed

deepmd/jax/fitting/fitting.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from deepmd.dpmodel.fitting.polarizability_fitting import (
1010
PolarFitting as PolarFittingNetDP,
1111
)
12+
from deepmd.dpmodel.fitting.property_fitting import (
13+
PropertyFittingNet as PropertyFittingNetDP,
14+
)
1215
from deepmd.jax.common import (
1316
ArrayAPIVariable,
1417
flax_module,
@@ -51,6 +54,14 @@ def __setattr__(self, name: str, value: Any) -> None:
5154
return super().__setattr__(name, value)
5255

5356

57+
@BaseFitting.register("property")
58+
@flax_module
59+
class PropertyFittingNet(PropertyFittingNetDP):
60+
def __setattr__(self, name: str, value: Any) -> None:
61+
value = setattr_for_general_fitting(name, value)
62+
return super().__setattr__(name, value)
63+
64+
5465
@BaseFitting.register("dos")
5566
@flax_module
5667
class DOSFittingNet(DOSFittingNetDP):

source/tests/array_api_strict/fitting/fitting.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from deepmd.dpmodel.fitting.polarizability_fitting import (
1010
PolarFitting as PolarFittingNetDP,
1111
)
12+
from deepmd.dpmodel.fitting.property_fitting import (
13+
PropertyFittingNet as PropertyFittingNetDP,
14+
)
1215

1316
from ..common import (
1417
to_array_api_strict_array,
@@ -43,6 +46,12 @@ def __setattr__(self, name: str, value: Any) -> None:
4346
return super().__setattr__(name, value)
4447

4548

49+
class PropertyFittingNet(PropertyFittingNetDP):
50+
def __setattr__(self, name: str, value: Any) -> None:
51+
value = setattr_for_general_fitting(name, value)
52+
return super().__setattr__(name, value)
53+
54+
4655
class DOSFittingNet(DOSFittingNetDP):
4756
def __setattr__(self, name: str, value: Any) -> None:
4857
value = setattr_for_general_fitting(name, value)

source/tests/consistent/fitting/test_property.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
)
1818

1919
from ..common import (
20+
INSTALLED_ARRAY_API_STRICT,
21+
INSTALLED_JAX,
2022
INSTALLED_PT,
2123
CommonTest,
2224
parameterized,
@@ -32,6 +34,22 @@
3234
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
3335
else:
3436
PropertyFittingPT = object
37+
if INSTALLED_JAX:
38+
from deepmd.jax.env import (
39+
jnp,
40+
)
41+
from deepmd.jax.fitting.fitting import PropertyFittingNet as PropertyFittingJAX
42+
else:
43+
PropertyFittingJAX = object
44+
if INSTALLED_ARRAY_API_STRICT:
45+
import array_api_strict
46+
47+
from ...array_api_strict.fitting.fitting import (
48+
PropertyFittingNet as PropertyFittingStrict,
49+
)
50+
else:
51+
PropertyFittingStrict = object
52+
3553
PropertyFittingTF = object
3654

3755

@@ -84,9 +102,14 @@ def skip_pt(self) -> bool:
84102
def skip_tf(self) -> bool:
85103
return True
86104

105+
skip_jax = not INSTALLED_JAX
106+
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT
107+
87108
tf_class = PropertyFittingTF
88109
dp_class = PropertyFittingDP
89110
pt_class = PropertyFittingPT
111+
jax_class = PropertyFittingJAX
112+
array_api_strict_class = PropertyFittingStrict
90113
args = fitting_property()
91114

92115
def setUp(self):
@@ -183,6 +206,45 @@ def eval_dp(self, dp_obj: Any) -> Any:
183206
aparam=self.aparam if numb_aparam else None,
184207
)["property"]
185208

209+
def eval_jax(self, jax_obj: Any) -> Any:
210+
(
211+
resnet_dt,
212+
precision,
213+
mixed_types,
214+
numb_fparam,
215+
numb_aparam,
216+
task_dim,
217+
intensive,
218+
) = self.param
219+
return np.asarray(
220+
jax_obj(
221+
jnp.asarray(self.inputs),
222+
jnp.asarray(self.atype.reshape(1, -1)),
223+
fparam=jnp.asarray(self.fparam) if numb_fparam else None,
224+
aparam=jnp.asarray(self.aparam) if numb_aparam else None,
225+
)["property"]
226+
)
227+
228+
def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any:
229+
array_api_strict.set_array_api_strict_flags(api_version="2023.12")
230+
(
231+
resnet_dt,
232+
precision,
233+
mixed_types,
234+
numb_fparam,
235+
numb_aparam,
236+
task_dim,
237+
intensive,
238+
) = self.param
239+
return np.asarray(
240+
array_api_strict_obj(
241+
array_api_strict.asarray(self.inputs),
242+
array_api_strict.asarray(self.atype.reshape(1, -1)),
243+
fparam=array_api_strict.asarray(self.fparam) if numb_fparam else None,
244+
aparam=array_api_strict.asarray(self.aparam) if numb_aparam else None,
245+
)["property"]
246+
)
247+
186248
def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
187249
if backend == self.RefBackend.TF:
188250
# shape is not same

0 commit comments

Comments
 (0)