
Commit fb41a4f

feat(jax): atomic virial (#4290)
For the frozen model, store two exported functions: one enables `do_atomic_virial` and the other does not. This PR conflicts with #4285 (in `serialization.py`); the conflict must be resolved after one of them is merged.

Release notes (auto-generated by coderabbit.ai):

- **New Features**
  - Introduced a new parameter for enhanced atomic virial data handling in model evaluations.
  - Added support for atomic virial calculations in multiple model evaluation methods.
  - Updated export functionality to dynamically include atomic virial data based on user input.
- **Bug Fixes**
  - Improved output structures across various backends to accommodate the new atomic virial data.
- **Tests**
  - Enhanced test cases to verify the new atomic virial functionality and ensure compatibility with existing evaluations.

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e1c868e commit fb41a4f
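Why two exported programs: `jax_export` serializes a concrete traced computation, so a Python-level switch such as `do_atomic_virial` cannot stay a runtime argument of the frozen model; it has to be baked in when the function is traced and exported. The sketch below is a toy illustration of that pattern, not the project's `call_lower`; the `jax.export` import alias, the toy function, and the float32 dtype are assumptions.

```python
import jax
import jax.numpy as jnp
from jax import export as jax_export  # assumed alias; recent JAX versions expose jax.export


def toy_call_lower(coord, with_extra_output: bool):
    # stand-in for the model call: the Python bool decides which outputs exist
    energy = jnp.sum(coord**2)
    return (energy, 2.0 * coord) if with_extra_output else (energy,)


def export_variant(flag: bool):
    # the flag is fixed at trace time, so each value yields a separate program
    return jax_export.export(jax.jit(lambda c: toy_call_lower(c, flag)))(
        jax.ShapeDtypeStruct((4, 3), jnp.float32)
    )


blob_plain = export_variant(False).serialize()   # analogous to "stablehlo"
blob_atomic = export_variant(True).serialize()   # analogous to "stablehlo_atomic_virial"
```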

File tree (7 files changed, +131 −22 lines):

- deepmd/jax/infer/deep_eval.py
- deepmd/jax/model/base_model.py
- deepmd/jax/model/hlo.py
- deepmd/jax/utils/serialization.py
- source/tests/consistent/io/test_io.py
- source/tests/consistent/model/common.py
- source/tests/consistent/model/test_ener.py

deepmd/jax/infer/deep_eval.py

Lines changed: 3 additions & 0 deletions

@@ -93,6 +93,9 @@ def __init__(
         model_data = load_dp_model(model_file)
         self.dp = HLO(
             stablehlo=model_data["@variables"]["stablehlo"].tobytes(),
+            stablehlo_atomic_virial=model_data["@variables"][
+                "stablehlo_atomic_virial"
+            ].tobytes(),
             model_def_script=model_data["model_def_script"],
             **model_data["constants"],
         )

deepmd/jax/model/base_model.py

Lines changed: 54 additions & 6 deletions

@@ -91,17 +91,65 @@ def eval_output(
         assert vdef.r_differentiable
         # avr: [nf, *def, nall, 3, 3]
         avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord)
+        # the correction sums to zero, which does not contribute to global virial
+        if do_atomic_virial:
+
+            def eval_ce(
+                cc_ext,
+                extended_atype,
+                nlist,
+                mapping,
+                fparam,
+                aparam,
+                *,
+                _kk=kk,
+                _atom_axis=atom_axis - 1,
+            ):
+                # atomic_ret[_kk]: [nf, nloc, *def]
+                atomic_ret = self.atomic_model.forward_common_atomic(
+                    cc_ext[None, ...],
+                    extended_atype[None, ...],
+                    nlist[None, ...],
+                    mapping=mapping[None, ...] if mapping is not None else None,
+                    fparam=fparam[None, ...] if fparam is not None else None,
+                    aparam=aparam[None, ...] if aparam is not None else None,
+                )
+                nloc = nlist.shape[0]
+                cc_loc = jax.lax.stop_gradient(cc_ext)[:nloc, ...]
+                cc_loc = jnp.reshape(cc_loc, [nloc, *[1] * def_ndim, 3])
+                # [*def, 3]
+                return jnp.sum(
+                    atomic_ret[_kk][0, ..., None] * cc_loc, axis=_atom_axis
+                )
+
+            # extended_virial_corr: [nf, *def, 3, nall, 3]
+            extended_virial_corr = jax.vmap(jax.jacrev(eval_ce, argnums=0))(
+                extended_coord,
+                extended_atype,
+                nlist,
+                mapping,
+                fparam,
+                aparam,
+            )
+            # move the first 3 to the last
+            # [nf, *def, nall, 3, 3]
+            extended_virial_corr = jnp.transpose(
+                extended_virial_corr,
+                [
+                    0,
+                    *range(1, def_ndim + 1),
+                    def_ndim + 2,
+                    def_ndim + 3,
+                    def_ndim + 1,
+                ],
+            )
+            avr += extended_virial_corr
+        # to [...,3,3] -> [...,9]
         # avr: [nf, *def, nall, 9]
         avr = jnp.reshape(avr, [*ff.shape[:-1], 9])
         # extended_virial: [nf, nall, *def, 9]
         extended_virial = jnp.transpose(
             avr, [0, def_ndim + 1, *range(1, def_ndim + 1), def_ndim + 2]
         )
-
-        # the correction sums to zero, which does not contribute to global virial
-        # cannot jit
-        # if do_atomic_virial:
-        #     raise NotImplementedError("Atomic virial is not implemented yet.")
-        # to [...,3,3] -> [...,9]
         model_predict[kk_derv_c] = extended_virial
         return model_predict
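The correction above is computed with `jax.vmap(jax.jacrev(eval_ce, argnums=0))`. Below is a minimal sketch of that mechanism using a toy per-frame function rather than the atomic model: `jacrev` returns the Jacobian of one frame's output with respect to its extended coordinates, and `vmap` batches that over the frame axis. In the real code the output carries extra `*def` axes, which the subsequent `jnp.transpose` reorders.

```python
import jax
import jax.numpy as jnp


def per_frame_quantity(coord):
    # toy stand-in for eval_ce: a scalar built from one frame's coordinates [nall, 3]
    return jnp.sum(coord**2)


coords = jnp.ones((4, 5, 3))  # [nf=4, nall=5, 3]
# Jacobian of the per-frame output w.r.t. coord, batched over frames -> [nf, nall, 3]
jac = jax.vmap(jax.jacrev(per_frame_quantity, argnums=0))(coords)
print(jac.shape)  # (4, 5, 3)
```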

deepmd/jax/model/hlo.py

Lines changed: 9 additions & 2 deletions

@@ -45,6 +45,7 @@ class HLO(BaseModel):
     def __init__(
         self,
         stablehlo,
+        stablehlo_atomic_virial,
         model_def_script,
         type_map,
         rcut,
@@ -58,6 +59,9 @@ def __init__(
         sel,
     ) -> None:
         self._call_lower = jax_export.deserialize(stablehlo).call
+        self._call_lower_atomic_virial = jax_export.deserialize(
+            stablehlo_atomic_virial
+        ).call
         self.stablehlo = stablehlo
         self.type_map = type_map
         self.rcut = rcut
@@ -170,14 +174,17 @@ def call_lower(
         aparam: Optional[jnp.ndarray] = None,
         do_atomic_virial: bool = False,
     ):
-        return self._call_lower(
+        if do_atomic_virial:
+            call_lower = self._call_lower_atomic_virial
+        else:
+            call_lower = self._call_lower
+        return call_lower(
             extended_coord,
             extended_atype,
             nlist,
             mapping,
             fparam,
             aparam,
-            do_atomic_virial,
         )
 
     def get_type_map(self) -> list[str]:
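Both serialized programs are deserialized once in `__init__`, and `call_lower` picks one with an ordinary Python branch, so the flag never enters the compiled computation. A rough sketch of that dispatch pattern follows; the class and attribute names are illustrative, not the project's.

```python
from jax import export as jax_export  # assumed alias matching the module's jax_export


class TwoProgramCaller:
    """Illustrative dispatch between two deserialized StableHLO programs."""

    def __init__(self, stablehlo: bytearray, stablehlo_atomic_virial: bytearray) -> None:
        # deserialize each exported blob once, up front
        self._call_plain = jax_export.deserialize(stablehlo).call
        self._call_atomic_virial = jax_export.deserialize(stablehlo_atomic_virial).call

    def __call__(self, *args, do_atomic_virial: bool = False):
        # plain Python branch: the flag stays outside the compiled program
        call = self._call_atomic_virial if do_atomic_virial else self._call_plain
        return call(*args)
```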

deepmd/jax/utils/serialization.py

Lines changed: 37 additions & 12 deletions

@@ -52,23 +52,48 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
         call_lower = model.call_lower
 
         nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")
-        exported = jax_export.export(jax.jit(call_lower))(
-            jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64),  # extended_coord
-            jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32),  # extended_atype
-            jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64),  # nlist
-            jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64),  # mapping
-            jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64)
-            if model.get_dim_fparam()
-            else None,  # fparam
-            jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64)
-            if model.get_dim_aparam()
-            else None,  # aparam
-            False,  # do_atomic_virial
+
+        def exported_whether_do_atomic_virial(do_atomic_virial):
+            def call_lower_with_fixed_do_atomic_virial(
+                coord, atype, nlist, nlist_start, fparam, aparam
+            ):
+                return call_lower(
+                    coord,
+                    atype,
+                    nlist,
+                    nlist_start,
+                    fparam,
+                    aparam,
+                    do_atomic_virial=do_atomic_virial,
+                )
+
+            return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))(
+                jax.ShapeDtypeStruct(
+                    (nf, nloc + nghost, 3), jnp.float64
+                ),  # extended_coord
+                jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32),  # extended_atype
+                jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64),  # nlist
+                jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64),  # mapping
+                jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64)
+                if model.get_dim_fparam()
+                else None,  # fparam
+                jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64)
+                if model.get_dim_aparam()
+                else None,  # aparam
+            )
+
+        exported = exported_whether_do_atomic_virial(do_atomic_virial=False)
+        exported_atomic_virial = exported_whether_do_atomic_virial(
+            do_atomic_virial=True
         )
         serialized: bytearray = exported.serialize()
+        serialized_atomic_virial = exported_atomic_virial.serialize()
         data = data.copy()
         data.setdefault("@variables", {})
         data["@variables"]["stablehlo"] = np.void(serialized)
+        data["@variables"]["stablehlo_atomic_virial"] = np.void(
+            serialized_atomic_virial
+        )
         data["constants"] = {
             "type_map": model.get_type_map(),
             "rcut": model.get_rcut(),

source/tests/consistent/io/test_io.py

Lines changed: 9 additions & 0 deletions

@@ -163,6 +163,15 @@ def test_deep_eval(self):
                 aparam=aparam,
             )
             rets.append(ret)
+            ret = deep_eval.eval(
+                self.coords,
+                self.box,
+                self.atype,
+                fparam=fparam,
+                aparam=aparam,
+                do_atomic_virial=True,
+            )
+            rets.append(ret)
         for ret in rets[1:]:
             for vv1, vv2 in zip(rets[0], ret):
                 if np.isnan(vv2).all():

source/tests/consistent/model/common.py

Lines changed: 9 additions & 1 deletion

@@ -51,7 +51,13 @@ def build_tf_model(self, obj, natoms, coords, atype, box, suffix):
             {},
             suffix=suffix,
         )
-        return [ret["energy"], ret["atom_ener"], ret["force"], ret["virial"]], {
+        return [
+            ret["energy"],
+            ret["atom_ener"],
+            ret["force"],
+            ret["virial"],
+            ret["atom_virial"],
+        ], {
             t_coord: coords,
             t_type: atype,
             t_natoms: natoms,
@@ -69,6 +75,7 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any:
                 numpy_to_torch(coords),
                 numpy_to_torch(atype),
                 box=numpy_to_torch(box),
+                do_atomic_virial=True,
             ).items()
         }
 
@@ -83,5 +90,6 @@ def assert_jax_array(arr):
                 numpy_to_jax(coords),
                 numpy_to_jax(atype),
                 box=numpy_to_jax(box),
+                do_atomic_virial=True,
             ).items()
         }

source/tests/consistent/model/test_ener.py

Lines changed: 10 additions & 1 deletion

@@ -216,21 +216,30 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
                 ret["energy"].ravel(),
                 SKIP_FLAG,
                 SKIP_FLAG,
+                SKIP_FLAG,
             )
         elif backend is self.RefBackend.PT:
             return (
                 ret["energy"].ravel(),
                 ret["atom_energy"].ravel(),
                 ret["force"].ravel(),
                 ret["virial"].ravel(),
+                ret["atom_virial"].ravel(),
             )
         elif backend is self.RefBackend.TF:
-            return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel())
+            return (
+                ret[0].ravel(),
+                ret[1].ravel(),
+                ret[2].ravel(),
+                ret[3].ravel(),
+                ret[4].ravel(),
+            )
         elif backend is self.RefBackend.JAX:
             return (
                 ret["energy_redu"].ravel(),
                 ret["energy"].ravel(),
                 ret["energy_derv_r"].ravel(),
                 ret["energy_derv_c_redu"].ravel(),
+                ret["energy_derv_c"].ravel(),
             )
         raise ValueError(f"Unknown backend: {backend}")
