55 Union ,
66)
77
8+ import array_api_compat
89import numpy as np
910
11+ from deepmd .dpmodel .array_api import (
12+ xp_take_along_axis ,
13+ )
1014from deepmd .dpmodel .output_def import (
1115 FittingOutputDef ,
1216 OutputVariableDef ,
1317)
18+ from deepmd .dpmodel .utils .safe_gradient import (
19+ safe_for_sqrt ,
20+ )
1421from deepmd .utils .pair_tab import (
1522 PairTab ,
1623)
@@ -74,9 +81,10 @@ def __init__(
7481 self .atom_ener = atom_ener
7582
7683 if self .tab_file is not None :
77- self .tab_info , self .tab_data = self .tab .get ()
78- nspline , ntypes_tab = self .tab_info [- 2 :].astype (int )
79- self .tab_data = self .tab_data .reshape (ntypes_tab , ntypes_tab , nspline , 4 )
84+ tab_info , tab_data = self .tab .get ()
85+ nspline , ntypes_tab = tab_info [- 2 :].astype (int )
86+ self .tab_info = tab_info
87+ self .tab_data = tab_data .reshape (ntypes_tab , ntypes_tab , nspline , 4 )
8088 if self .ntypes != ntypes_tab :
8189 raise ValueError (
8290 "The `type_map` provided does not match the number of columns in the table."
@@ -189,8 +197,9 @@ def forward_atomic(
189197 fparam : Optional [np .ndarray ] = None ,
190198 aparam : Optional [np .ndarray ] = None ,
191199 ) -> dict [str , np .ndarray ]:
200+ xp = array_api_compat .array_namespace (extended_coord , extended_atype , nlist )
192201 nframes , nloc , nnei = nlist .shape
193- extended_coord = extended_coord .reshape (nframes , - 1 , 3 )
202+ extended_coord = xp .reshape (extended_coord , ( nframes , - 1 , 3 ) )
194203
195204 # this will mask all -1 in the nlist
196205 mask = nlist >= 0
@@ -200,23 +209,21 @@ def forward_atomic(
200209 pairwise_rr = self ._get_pairwise_dist (
201210 extended_coord , masked_nlist
202211 ) # (nframes, nloc, nnei)
203- self .tab_data = self .tab_data .reshape (
204- self .tab .ntypes , self .tab .ntypes , self .tab .nspline , 4
205- )
206212
207213 # (nframes, nloc, nnei), index type is int64.
208214 j_type = extended_atype [
209- np .arange (extended_atype .shape [0 ], dtype = np .int64 )[:, None , None ],
215+ xp .arange (extended_atype .shape [0 ], dtype = xp .int64 )[:, None , None ],
210216 masked_nlist ,
211217 ]
212218
213219 raw_atomic_energy = self ._pair_tabulated_inter (
214220 nlist , atype , j_type , pairwise_rr
215221 )
216- atomic_energy = 0.5 * np .sum (
217- np .where (nlist != - 1 , raw_atomic_energy , np .zeros_like (raw_atomic_energy )),
222+ atomic_energy = 0.5 * xp .sum (
223+ xp .where (nlist != - 1 , raw_atomic_energy , xp .zeros_like (raw_atomic_energy )),
218224 axis = - 1 ,
219- ).reshape (nframes , nloc , 1 )
225+ )
226+ atomic_energy = xp .reshape (atomic_energy , (nframes , nloc , 1 ))
220227
221228 return {"energy" : atomic_energy }
222229
@@ -255,36 +262,42 @@ def _pair_tabulated_inter(
255262 This function is used to calculate the pairwise energy between two atoms.
256263 It uses a table containing cubic spline coefficients calculated in PairTab.
257264 """
265+ xp = array_api_compat .array_namespace (nlist , i_type , j_type , rr )
258266 nframes , nloc , nnei = nlist .shape
259267 rmin = self .tab_info [0 ]
260268 hh = self .tab_info [1 ]
261269 hi = 1.0 / hh
262270
263- nspline = int (self .tab_info [2 ] + 0.1 )
271+ # jax jit does not support convert to a Python int, so we need to convert to xp.int64.
272+ nspline = (self .tab_info [2 ] + 0.1 ).astype (xp .int64 )
264273
265274 uu = (rr - rmin ) * hi # this is broadcasted to (nframes,nloc,nnei)
266275
267276 # if nnei of atom 0 has -1 in the nlist, uu would be 0.
268277 # this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms.
269- uu = np .where (nlist != - 1 , uu , nspline + 1 )
278+ uu = xp .where (nlist != - 1 , uu , nspline + 1 )
270279
271- if np .any (uu < 0 ):
272- raise Exception ("coord go beyond table lower boundary" )
280+ # unsupported by jax
281+ # if xp.any(uu < 0):
282+ # raise Exception("coord go beyond table lower boundary")
273283
274- idx = uu .astype (int )
284+ idx = xp .astype (uu , xp . int64 )
275285
276286 uu -= idx
277287 table_coef = self ._extract_spline_coefficient (
278288 i_type , j_type , idx , self .tab_data , nspline
279289 )
280- table_coef = table_coef .reshape (nframes , nloc , nnei , 4 )
290+ table_coef = xp .reshape (table_coef , ( nframes , nloc , nnei , 4 ) )
281291 ener = self ._calculate_ener (table_coef , uu )
282292 # here we need to overwrite energy to zero at rcut and beyond.
283293 mask_beyond_rcut = rr >= self .rcut
284294 # also overwrite values beyond extrapolation to zero
285295 extrapolation_mask = rr >= self .tab .rmin + nspline * self .tab .hh
286- ener [mask_beyond_rcut ] = 0
287- ener [extrapolation_mask ] = 0
296+ ener = xp .where (
297+ xp .logical_or (mask_beyond_rcut , extrapolation_mask ),
298+ xp .zeros_like (ener ),
299+ ener ,
300+ )
288301
289302 return ener
290303
@@ -304,12 +317,13 @@ def _get_pairwise_dist(coords: np.ndarray, nlist: np.ndarray) -> np.ndarray:
304317 np.ndarray
305318 The pairwise distance between the atoms (nframes, nloc, nnei).
306319 """
320+ xp = array_api_compat .array_namespace (coords , nlist )
307321 # index type is int64
308- batch_indices = np .arange (nlist .shape [0 ], dtype = np .int64 )[:, None , None ]
322+ batch_indices = xp .arange (nlist .shape [0 ], dtype = xp .int64 )[:, None , None ]
309323 neighbor_atoms = coords [batch_indices , nlist ]
310324 loc_atoms = coords [:, : nlist .shape [1 ], :]
311325 pairwise_dr = loc_atoms [:, :, None , :] - neighbor_atoms
312- pairwise_rr = np . sqrt ( np .sum (np .power (pairwise_dr , 2 ), axis = - 1 ))
326+ pairwise_rr = safe_for_sqrt ( xp .sum (xp .power (pairwise_dr , 2 ), axis = - 1 ))
313327
314328 return pairwise_rr
315329
@@ -319,7 +333,7 @@ def _extract_spline_coefficient(
319333 j_type : np .ndarray ,
320334 idx : np .ndarray ,
321335 tab_data : np .ndarray ,
322- nspline : int ,
336+ nspline : np . int64 ,
323337 ) -> np .ndarray :
324338 """Extract the spline coefficient from the table.
325339
@@ -341,28 +355,31 @@ def _extract_spline_coefficient(
341355 np.ndarray
342356 The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed.
343357 """
358+ xp = array_api_compat .array_namespace (i_type , j_type , idx , tab_data )
344359 # (nframes, nloc, nnei)
345- expanded_i_type = np .broadcast_to (
346- i_type [:, :, np .newaxis ],
360+ expanded_i_type = xp .broadcast_to (
361+ i_type [:, :, xp .newaxis ],
347362 (i_type .shape [0 ], i_type .shape [1 ], j_type .shape [- 1 ]),
348363 )
349364
350365 # (nframes, nloc, nnei, nspline, 4)
351366 expanded_tab_data = tab_data [expanded_i_type , j_type ]
352367
353368 # (nframes, nloc, nnei, 1, 4)
354- expanded_idx = np .broadcast_to (
355- idx [..., np .newaxis , np .newaxis ], (* idx .shape , 1 , 4 )
369+ expanded_idx = xp .broadcast_to (
370+ idx [..., xp .newaxis , xp .newaxis ], (* idx .shape , 1 , 4 )
356371 )
357- clipped_indices = np .clip (expanded_idx , 0 , nspline - 1 ).astype (int )
372+ clipped_indices = xp .clip (expanded_idx , 0 , nspline - 1 ).astype (int )
358373
359374 # (nframes, nloc, nnei, 4)
360- final_coef = np .squeeze (
361- np . take_along_axis (expanded_tab_data , clipped_indices , 3 )
375+ final_coef = xp .squeeze (
376+ xp_take_along_axis (expanded_tab_data , clipped_indices , 3 )
362377 )
363378
364379 # when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`.
365- final_coef [expanded_idx .squeeze () > nspline ] = 0
380+ final_coef = xp .where (
381+ expanded_idx .squeeze () > nspline , xp .zeros_like (final_coef ), final_coef
382+ )
366383 return final_coef
367384
368385 @staticmethod
0 commit comments