Skip to content

Commit 7aaf284

Browse files
authored
feat(jax): zbl (#4301)
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced new classes: `DPZBLLinearEnergyAtomicModel` and `PairTabAtomicModel`, enhancing atomic model functionalities.
  - Added `get_zbl_model` function for constructing `DPZBLModel` from input data.
  - Improved error handling in vector normalization with `safe_for_vector_norm` and `safe_for_sqrt`.
- **Bug Fixes**
  - Enhanced distance calculations in `format_nlist` to prevent NaN errors.
- **Documentation**
  - Updated comments and docstrings for clarity on recent changes.
- **Tests**
  - Enhanced test support for JAX backend in `test_zbl_ener.py`.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent bfbe2ed commit 7aaf284

File tree

11 files changed

+363
-54
lines changed

11 files changed

+363
-54
lines changed

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Union,
66
)
77

8+
import array_api_compat
89
import numpy as np
910

1011
from deepmd.dpmodel.utils.nlist import (
@@ -69,15 +70,16 @@ def __init__(
6970
self.models = models
7071
sub_model_type_maps = [md.get_type_map() for md in models]
7172
err_msg = []
72-
self.mapping_list = []
73+
mapping_list = []
7374
common_type_map = set(type_map)
7475
self.type_map = type_map
7576
for tpmp in sub_model_type_maps:
7677
if not common_type_map.issubset(set(tpmp)):
7778
err_msg.append(
7879
f"type_map {tpmp} is not a subset of type_map {type_map}"
7980
)
80-
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))
81+
mapping_list.append(self.remap_atype(tpmp, self.type_map))
82+
self.mapping_list = mapping_list
8183
assert len(err_msg) == 0, "\n".join(err_msg)
8284
self.mixed_types_list = [model.mixed_types() for model in self.models]
8385

@@ -212,8 +214,9 @@ def forward_atomic(
212214
result_dict
213215
the result dict, defined by the fitting net output def.
214216
"""
217+
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
215218
nframes, nloc, nnei = nlist.shape
216-
extended_coord = extended_coord.reshape(nframes, -1, 3)
219+
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))
217220
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
218221
nlists = build_multiple_neighbor_list(
219222
extended_coord,
@@ -244,10 +247,10 @@ def forward_atomic(
244247
aparam,
245248
)["energy"]
246249
)
247-
self.weights = self._compute_weight(extended_coord, extended_atype, nlists_)
250+
weights = self._compute_weight(extended_coord, extended_atype, nlists_)
248251

249252
fit_ret = {
250-
"energy": np.sum(np.stack(ener_list) * np.stack(self.weights), axis=0),
253+
"energy": xp.sum(xp.stack(ener_list) * xp.stack(weights), axis=0),
251254
} # (nframes, nloc, 1)
252255
return fit_ret
253256

@@ -320,11 +323,12 @@ def _compute_weight(
320323
nlists_: list[np.ndarray],
321324
) -> list[np.ndarray]:
322325
"""This should be a list of user defined weights that matches the number of models to be combined."""
326+
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlists_)
323327
nmodels = len(self.models)
324328
nframes, nloc, _ = nlists_[0].shape
325329
# the dtype of weights is the interface data type.
326330
return [
327-
np.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels
331+
xp.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels
328332
for _ in range(nmodels)
329333
]
330334

@@ -442,6 +446,7 @@ def _compute_weight(
442446
self.sw_rmax > self.sw_rmin
443447
), "The upper boundary `sw_rmax` must be greater than the lower boundary `sw_rmin`."
444448

449+
xp = array_api_compat.array_namespace(extended_coord, extended_atype)
445450
dp_nlist = nlists_[0]
446451
zbl_nlist = nlists_[1]
447452

@@ -450,40 +455,40 @@ def _compute_weight(
450455

451456
# use the larger rr based on nlist
452457
nlist_larger = zbl_nlist if zbl_nnei >= dp_nnei else dp_nlist
453-
masked_nlist = np.clip(nlist_larger, 0, None)
458+
masked_nlist = xp.clip(nlist_larger, 0, None)
454459
pairwise_rr = PairTabAtomicModel._get_pairwise_dist(
455460
extended_coord, masked_nlist
456461
)
457462

458-
numerator = np.sum(
459-
np.where(
463+
numerator = xp.sum(
464+
xp.where(
460465
nlist_larger != -1,
461-
pairwise_rr * np.exp(-pairwise_rr / self.smin_alpha),
462-
np.zeros_like(nlist_larger),
466+
pairwise_rr * xp.exp(-pairwise_rr / self.smin_alpha),
467+
xp.zeros_like(nlist_larger),
463468
),
464469
axis=-1,
465470
) # masked nnei will be zero, no need to handle
466-
denominator = np.sum(
467-
np.where(
471+
denominator = xp.sum(
472+
xp.where(
468473
nlist_larger != -1,
469-
np.exp(-pairwise_rr / self.smin_alpha),
470-
np.zeros_like(nlist_larger),
474+
xp.exp(-pairwise_rr / self.smin_alpha),
475+
xp.zeros_like(nlist_larger),
471476
),
472477
axis=-1,
473478
) # handle masked nnei.
474479
with np.errstate(divide="ignore", invalid="ignore"):
475480
sigma = numerator / denominator
476481
u = (sigma - self.sw_rmin) / (self.sw_rmax - self.sw_rmin)
477-
coef = np.zeros_like(u)
482+
coef = xp.zeros_like(u)
478483
left_mask = sigma < self.sw_rmin
479484
mid_mask = (self.sw_rmin <= sigma) & (sigma < self.sw_rmax)
480485
right_mask = sigma >= self.sw_rmax
481-
coef[left_mask] = 1
486+
coef = xp.where(left_mask, xp.ones_like(coef), coef)
482487
with np.errstate(invalid="ignore"):
483488
smooth = -6 * u**5 + 15 * u**4 - 10 * u**3 + 1
484-
coef[mid_mask] = smooth[mid_mask]
485-
coef[right_mask] = 0
489+
coef = xp.where(mid_mask, smooth, coef)
490+
coef = xp.where(right_mask, xp.zeros_like(coef), coef)
486491
# to handle masked atoms
487-
coef = np.where(sigma != 0, coef, np.zeros_like(coef))
492+
coef = xp.where(sigma != 0, coef, xp.zeros_like(coef))
488493
self.zbl_weight = coef
489-
return [1 - np.expand_dims(coef, -1), np.expand_dims(coef, -1)]
494+
return [1 - xp.expand_dims(coef, -1), xp.expand_dims(coef, -1)]

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,19 @@
55
Union,
66
)
77

8+
import array_api_compat
89
import numpy as np
910

11+
from deepmd.dpmodel.array_api import (
12+
xp_take_along_axis,
13+
)
1014
from deepmd.dpmodel.output_def import (
1115
FittingOutputDef,
1216
OutputVariableDef,
1317
)
18+
from deepmd.dpmodel.utils.safe_gradient import (
19+
safe_for_sqrt,
20+
)
1421
from deepmd.utils.pair_tab import (
1522
PairTab,
1623
)
@@ -74,9 +81,10 @@ def __init__(
7481
self.atom_ener = atom_ener
7582

7683
if self.tab_file is not None:
77-
self.tab_info, self.tab_data = self.tab.get()
78-
nspline, ntypes_tab = self.tab_info[-2:].astype(int)
79-
self.tab_data = self.tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
84+
tab_info, tab_data = self.tab.get()
85+
nspline, ntypes_tab = tab_info[-2:].astype(int)
86+
self.tab_info = tab_info
87+
self.tab_data = tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
8088
if self.ntypes != ntypes_tab:
8189
raise ValueError(
8290
"The `type_map` provided does not match the number of columns in the table."
@@ -189,8 +197,9 @@ def forward_atomic(
189197
fparam: Optional[np.ndarray] = None,
190198
aparam: Optional[np.ndarray] = None,
191199
) -> dict[str, np.ndarray]:
200+
xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist)
192201
nframes, nloc, nnei = nlist.shape
193-
extended_coord = extended_coord.reshape(nframes, -1, 3)
202+
extended_coord = xp.reshape(extended_coord, (nframes, -1, 3))
194203

195204
# this will mask all -1 in the nlist
196205
mask = nlist >= 0
@@ -200,23 +209,21 @@ def forward_atomic(
200209
pairwise_rr = self._get_pairwise_dist(
201210
extended_coord, masked_nlist
202211
) # (nframes, nloc, nnei)
203-
self.tab_data = self.tab_data.reshape(
204-
self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4
205-
)
206212

207213
# (nframes, nloc, nnei), index type is int64.
208214
j_type = extended_atype[
209-
np.arange(extended_atype.shape[0], dtype=np.int64)[:, None, None],
215+
xp.arange(extended_atype.shape[0], dtype=xp.int64)[:, None, None],
210216
masked_nlist,
211217
]
212218

213219
raw_atomic_energy = self._pair_tabulated_inter(
214220
nlist, atype, j_type, pairwise_rr
215221
)
216-
atomic_energy = 0.5 * np.sum(
217-
np.where(nlist != -1, raw_atomic_energy, np.zeros_like(raw_atomic_energy)),
222+
atomic_energy = 0.5 * xp.sum(
223+
xp.where(nlist != -1, raw_atomic_energy, xp.zeros_like(raw_atomic_energy)),
218224
axis=-1,
219-
).reshape(nframes, nloc, 1)
225+
)
226+
atomic_energy = xp.reshape(atomic_energy, (nframes, nloc, 1))
220227

221228
return {"energy": atomic_energy}
222229

@@ -255,36 +262,42 @@ def _pair_tabulated_inter(
255262
This function is used to calculate the pairwise energy between two atoms.
256263
It uses a table containing cubic spline coefficients calculated in PairTab.
257264
"""
265+
xp = array_api_compat.array_namespace(nlist, i_type, j_type, rr)
258266
nframes, nloc, nnei = nlist.shape
259267
rmin = self.tab_info[0]
260268
hh = self.tab_info[1]
261269
hi = 1.0 / hh
262270

263-
nspline = int(self.tab_info[2] + 0.1)
271+
# jax jit does not support convert to a Python int, so we need to convert to xp.int64.
272+
nspline = (self.tab_info[2] + 0.1).astype(xp.int64)
264273

265274
uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei)
266275

267276
# if nnei of atom 0 has -1 in the nlist, uu would be 0.
268277
# this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms.
269-
uu = np.where(nlist != -1, uu, nspline + 1)
278+
uu = xp.where(nlist != -1, uu, nspline + 1)
270279

271-
if np.any(uu < 0):
272-
raise Exception("coord go beyond table lower boundary")
280+
# unsupported by jax
281+
# if xp.any(uu < 0):
282+
# raise Exception("coord go beyond table lower boundary")
273283

274-
idx = uu.astype(int)
284+
idx = xp.astype(uu, xp.int64)
275285

276286
uu -= idx
277287
table_coef = self._extract_spline_coefficient(
278288
i_type, j_type, idx, self.tab_data, nspline
279289
)
280-
table_coef = table_coef.reshape(nframes, nloc, nnei, 4)
290+
table_coef = xp.reshape(table_coef, (nframes, nloc, nnei, 4))
281291
ener = self._calculate_ener(table_coef, uu)
282292
# here we need to overwrite energy to zero at rcut and beyond.
283293
mask_beyond_rcut = rr >= self.rcut
284294
# also overwrite values beyond extrapolation to zero
285295
extrapolation_mask = rr >= self.tab.rmin + nspline * self.tab.hh
286-
ener[mask_beyond_rcut] = 0
287-
ener[extrapolation_mask] = 0
296+
ener = xp.where(
297+
xp.logical_or(mask_beyond_rcut, extrapolation_mask),
298+
xp.zeros_like(ener),
299+
ener,
300+
)
288301

289302
return ener
290303

@@ -304,12 +317,13 @@ def _get_pairwise_dist(coords: np.ndarray, nlist: np.ndarray) -> np.ndarray:
304317
np.ndarray
305318
The pairwise distance between the atoms (nframes, nloc, nnei).
306319
"""
320+
xp = array_api_compat.array_namespace(coords, nlist)
307321
# index type is int64
308-
batch_indices = np.arange(nlist.shape[0], dtype=np.int64)[:, None, None]
322+
batch_indices = xp.arange(nlist.shape[0], dtype=xp.int64)[:, None, None]
309323
neighbor_atoms = coords[batch_indices, nlist]
310324
loc_atoms = coords[:, : nlist.shape[1], :]
311325
pairwise_dr = loc_atoms[:, :, None, :] - neighbor_atoms
312-
pairwise_rr = np.sqrt(np.sum(np.power(pairwise_dr, 2), axis=-1))
326+
pairwise_rr = safe_for_sqrt(xp.sum(xp.power(pairwise_dr, 2), axis=-1))
313327

314328
return pairwise_rr
315329

@@ -319,7 +333,7 @@ def _extract_spline_coefficient(
319333
j_type: np.ndarray,
320334
idx: np.ndarray,
321335
tab_data: np.ndarray,
322-
nspline: int,
336+
nspline: np.int64,
323337
) -> np.ndarray:
324338
"""Extract the spline coefficient from the table.
325339
@@ -341,28 +355,31 @@ def _extract_spline_coefficient(
341355
np.ndarray
342356
The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed.
343357
"""
358+
xp = array_api_compat.array_namespace(i_type, j_type, idx, tab_data)
344359
# (nframes, nloc, nnei)
345-
expanded_i_type = np.broadcast_to(
346-
i_type[:, :, np.newaxis],
360+
expanded_i_type = xp.broadcast_to(
361+
i_type[:, :, xp.newaxis],
347362
(i_type.shape[0], i_type.shape[1], j_type.shape[-1]),
348363
)
349364

350365
# (nframes, nloc, nnei, nspline, 4)
351366
expanded_tab_data = tab_data[expanded_i_type, j_type]
352367

353368
# (nframes, nloc, nnei, 1, 4)
354-
expanded_idx = np.broadcast_to(
355-
idx[..., np.newaxis, np.newaxis], (*idx.shape, 1, 4)
369+
expanded_idx = xp.broadcast_to(
370+
idx[..., xp.newaxis, xp.newaxis], (*idx.shape, 1, 4)
356371
)
357-
clipped_indices = np.clip(expanded_idx, 0, nspline - 1).astype(int)
372+
clipped_indices = xp.clip(expanded_idx, 0, nspline - 1).astype(int)
358373

359374
# (nframes, nloc, nnei, 4)
360-
final_coef = np.squeeze(
361-
np.take_along_axis(expanded_tab_data, clipped_indices, 3)
375+
final_coef = xp.squeeze(
376+
xp_take_along_axis(expanded_tab_data, clipped_indices, 3)
362377
)
363378

364379
# when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
365-
final_coef[expanded_idx.squeeze() > nspline] = 0
380+
final_coef = xp.where(
381+
expanded_idx.squeeze() > nspline, xp.zeros_like(final_coef), final_coef
382+
)
366383
return final_coef
367384

368385
@staticmethod

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
LayerNorm,
2828
NativeLayer,
2929
)
30+
from deepmd.dpmodel.utils.safe_gradient import (
31+
safe_for_vector_norm,
32+
)
3033
from deepmd.dpmodel.utils.seed import (
3134
child_seed,
3235
)
@@ -943,7 +946,7 @@ def call(
943946
else:
944947
raise NotImplementedError
945948

946-
normed = xp.linalg.vector_norm(
949+
normed = safe_for_vector_norm(
947950
xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True
948951
)
949952
input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum(
deepmd/dpmodel/utils/safe_gradient.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""Safe versions of some functions that have problematic gradients.
3+
4+
Check https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
5+
for more information.
6+
"""
7+
8+
import array_api_compat
9+
10+
11+
def safe_for_sqrt(x):
    """Compute an element-wise square root with a gradient of 0 at x = 0.

    Entries of ``x`` that are not strictly positive produce 0 in the output.
    Those entries are swapped for 1 before ``sqrt`` is evaluated, so the square
    root itself is only ever taken at strictly positive values.
    """
    xp = array_api_compat.array_namespace(x)
    positive = x > 0.0
    # Feed sqrt a harmless 1 wherever x is not strictly positive; the outer
    # `where` below discards that placeholder result again.
    guarded = xp.where(positive, x, xp.ones_like(x))
    root = xp.sqrt(guarded)
    return xp.where(positive, root, xp.zeros_like(x))
16+
17+
18+
def safe_for_vector_norm(x, /, *, axis=None, keepdims=False, ord=2):
    """Safe version of vector_norm that has a gradient of 0 at x = 0.

    Vectors whose sum of squares along ``axis`` is not strictly positive yield
    a norm of 0. Their entries are swapped for 1 before the norm is evaluated,
    so ``vector_norm`` is never called on an all-zero vector.

    Parameters
    ----------
    x
        Input array.
    axis
        Axis or axes along which the norm is taken; ``None`` reduces over the
        whole array. Forwarded to ``xp.linalg.vector_norm``.
    keepdims
        Whether the reduced axes are kept with size one in the result.
    ord
        Order of the norm, forwarded to ``xp.linalg.vector_norm``.

    Returns
    -------
    The norm of ``x`` along ``axis``, with 0 wherever the vector is zero.
    """
    xp = array_api_compat.array_namespace(x)
    # True where the vector has a non-zero entry; kept with the reduced axes
    # so it broadcasts against x in the inner `where`.
    mask = xp.sum(xp.square(x), axis=axis, keepdims=True) > 0
    if keepdims:
        mask_squeezed = mask
    elif axis is None:
        # A full reduction leaves exactly one element. The array API `squeeze`
        # requires an explicit axis, so collapse to a 0-d array by reshaping.
        mask_squeezed = xp.reshape(mask, ())
    else:
        mask_squeezed = xp.squeeze(mask, axis=axis)
    # Placeholder ones keep the norm away from zero vectors; the outer `where`
    # replaces the placeholder result with 0.
    norm = xp.linalg.vector_norm(
        xp.where(mask, x, xp.ones_like(x)), axis=axis, keepdims=keepdims, ord=ord
    )
    return xp.where(
        mask_squeezed,
        norm,
        xp.zeros_like(mask_squeezed, dtype=x.dtype),
    )

0 commit comments

Comments
 (0)