diff --git a/deepmd/dpmodel/utils/neighbor_stat.py b/deepmd/dpmodel/utils/neighbor_stat.py index 1bcc894624..0ccb419d68 100644 --- a/deepmd/dpmodel/utils/neighbor_stat.py +++ b/deepmd/dpmodel/utils/neighbor_stat.py @@ -87,7 +87,7 @@ def call( ) assert list(diff.shape) == [nframes, nloc, nall, 3] # remove the diagonal elements - mask = xp.eye(nloc, nall, dtype=xp.bool) + mask = xp.eye(nloc, nall, dtype=xp.bool, device=array_api_compat.device(diff)) mask = xp.tile(mask[None, :, :, None], (nframes, 1, 1, 3)) diff = xp.where(mask, xp.full_like(diff, xp.inf), diff) rr2 = xp.sum(xp.square(diff), axis=-1) diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index a43cf46403..eb5320bad1 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -117,7 +117,9 @@ def build_neighbor_list( assert list(diff.shape) == [batch_size, nloc, nall, 3] rr = xp.linalg.vector_norm(diff, axis=-1) # if central atom has two zero distances, sorting sometimes can not exclude itself - rr -= xp.eye(nloc, nall, dtype=diff.dtype)[xp.newaxis, :, :] + rr -= xp.eye(nloc, nall, dtype=diff.dtype, device=array_api_compat.device(diff))[ + xp.newaxis, :, : + ] nlist = xp.argsort(rr, axis=-1) rr = xp.sort(rr, axis=-1) rr = rr[:, :, 1:] @@ -128,11 +130,26 @@ def build_neighbor_list( nlist = nlist[:, :, :nsel] else: rr = xp.concatenate( - [rr, xp.ones([batch_size, nloc, nsel - nnei], dtype=rr.dtype) + rcut], + [ + rr, + xp.ones( + [batch_size, nloc, nsel - nnei], + dtype=rr.dtype, + device=array_api_compat.device(rr), + ) + + rcut, + ], axis=-1, ) nlist = xp.concatenate( - [nlist, xp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)], + [ + nlist, + xp.ones( + [batch_size, nloc, nsel - nnei], + dtype=nlist.dtype, + device=array_api_compat.device(nlist), + ), + ], axis=-1, ) assert list(nlist.shape) == [batch_size, nloc, nsel] @@ -218,7 +235,11 @@ def build_multiple_neighbor_list( return {} nb, nloc, nsel = nlist.shape if nsel < nsels[-1]: - pad = -1 * xp.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype) + pad = -1 * xp.ones( + (nb, nloc, nsels[-1] - nsel), + dtype=nlist.dtype, + device=array_api_compat.device(nlist), + ) nlist = xp.concat([nlist, pad], axis=-1) nsel = nsels[-1] coord1 = xp.reshape(coord, (nb, -1, 3)) @@ -276,7 +297,12 @@ def extend_coord_with_ghosts( xp = array_api_compat.array_namespace(coord, atype) nf, nloc = atype.shape # int64 for index - aidx = xp.tile(xp.arange(nloc, dtype=xp.int64)[xp.newaxis, :], (nf, 1)) + aidx = xp.tile( + xp.arange(nloc, dtype=xp.int64, device=array_api_compat.device(atype))[ + xp.newaxis, : + ], + (nf, 1), + ) if cell is None: nall = nloc extend_coord = coord @@ -288,17 +314,41 @@ def extend_coord_with_ghosts( to_face = to_face_distance(cell) nbuff = xp.astype(xp.ceil(rcut / to_face), xp.int64) nbuff = xp.max(nbuff, axis=0) - xi = xp.arange(-int(nbuff[0]), int(nbuff[0]) + 1, 1, dtype=xp.int64) - yi = xp.arange(-int(nbuff[1]), int(nbuff[1]) + 1, 1, dtype=xp.int64) - zi = xp.arange(-int(nbuff[2]), int(nbuff[2]) + 1, 1, dtype=xp.int64) - xyz = xp.linalg.outer(xi, xp.asarray([1, 0, 0]))[:, xp.newaxis, xp.newaxis, :] + xi = xp.arange( + -int(nbuff[0]), + int(nbuff[0]) + 1, + 1, + dtype=xp.int64, + device=array_api_compat.device(coord), + ) + yi = xp.arange( + -int(nbuff[1]), + int(nbuff[1]) + 1, + 1, + dtype=xp.int64, + device=array_api_compat.device(coord), + ) + zi = xp.arange( + -int(nbuff[2]), + int(nbuff[2]) + 1, + 1, + dtype=xp.int64, + device=array_api_compat.device(coord), + ) + xyz = xp.linalg.outer( + xi, xp.asarray([1, 0, 0], device=array_api_compat.device(xi)) + )[:, xp.newaxis, xp.newaxis, :] xyz = ( xyz - + xp.linalg.outer(yi, xp.asarray([0, 1, 0]))[xp.newaxis, :, xp.newaxis, :] + + xp.linalg.outer( + yi, xp.asarray([0, 1, 0], device=array_api_compat.device(yi)) + )[xp.newaxis, :, xp.newaxis, :] ) xyz = ( xyz - + xp.linalg.outer(zi, xp.asarray([0, 0, 1]))[xp.newaxis, xp.newaxis, :, :] + + xp.linalg.outer( + zi, xp.asarray([0, 0, 1], device=array_api_compat.device(zi)) + )[xp.newaxis, xp.newaxis, :, :] ) xyz = xp.reshape(xyz, (-1, 3)) xyz = xp.astype(xyz, coord.dtype) diff --git a/deepmd/pt/utils/neighbor_stat.py b/deepmd/pt/utils/neighbor_stat.py index 292a27080b..98a2eabfb6 100644 --- a/deepmd/pt/utils/neighbor_stat.py +++ b/deepmd/pt/utils/neighbor_stat.py @@ -6,111 +6,21 @@ import numpy as np import torch +from deepmd.dpmodel.utils.neighbor_stat import ( + NeighborStatOP, +) from deepmd.pt.utils.auto_batch_size import ( AutoBatchSize, ) from deepmd.pt.utils.env import ( DEVICE, ) -from deepmd.pt.utils.nlist import ( - extend_coord_with_ghosts, -) from deepmd.utils.data_system import ( DeepmdDataSystem, ) from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat -class NeighborStatOP(torch.nn.Module): - """Class for getting neighbor statistics data information. - - Parameters - ---------- - ntypes - The num of atom types - rcut - The cut-off radius - mixed_types : bool, optional - If True, treat neighbors of all types as a single type. - """ - - def __init__( - self, - ntypes: int, - rcut: float, - mixed_types: bool, - ) -> None: - super().__init__() - self.rcut = float(rcut) - self.ntypes = ntypes - self.mixed_types = mixed_types - - def forward( - self, - coord: torch.Tensor, - atype: torch.Tensor, - cell: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Calculate the neareest neighbor distance between atoms, maximum nbor size of - atoms and the output data range of the environment matrix. - - Parameters - ---------- - coord - The coordinates of atoms. - atype - The atom types. - cell - The cell. - - Returns - ------- - torch.Tensor - The minimal squared distance between two atoms, in the shape of (nframes,) - torch.Tensor - The maximal number of neighbors - """ - nframes = coord.shape[0] - coord = coord.view(nframes, -1, 3) - nloc = coord.shape[1] - coord = coord.view(nframes, nloc * 3) - extend_coord, extend_atype, _ = extend_coord_with_ghosts( - coord, atype, cell, self.rcut - ) - - coord1 = extend_coord.reshape(nframes, -1) - nall = coord1.shape[1] // 3 - coord0 = coord1[:, : nloc * 3] - diff = ( - coord1.reshape([nframes, -1, 3])[:, None, :, :] - - coord0.reshape([nframes, -1, 3])[:, :, None, :] - ) - assert list(diff.shape) == [nframes, nloc, nall, 3] - # remove the diagonal elements - mask = torch.eye(nloc, nall, dtype=torch.bool, device=diff.device) - diff[:, mask] = torch.inf - rr2 = torch.sum(torch.square(diff), dim=-1) - min_rr2, _ = torch.min(rr2, dim=-1) - # count the number of neighbors - if not self.mixed_types: - mask = rr2 < self.rcut**2 - nnei = torch.zeros( - (nframes, nloc, self.ntypes), dtype=torch.int32, device=mask.device - ) - for ii in range(self.ntypes): - nnei[:, :, ii] = torch.sum( - mask & extend_atype.eq(ii)[:, None, :], dim=-1 - ) - else: - mask = rr2 < self.rcut**2 - # virtual types (<0) are not counted - nnei = torch.sum(mask & extend_atype.ge(0)[:, None, :], dim=-1).view( - nframes, nloc, 1 - ) - max_nnei, _ = torch.max(nnei, dim=1) - return min_rr2, max_nnei - - class NeighborStat(BaseNeighborStat): """Neighbor statistics using pure NumPy. @@ -131,8 +41,7 @@ def __init__( mixed_type: bool = False, ) -> None: super().__init__(ntypes, rcut, mixed_type) - op = NeighborStatOP(ntypes, rcut, mixed_type) - self.op = torch.jit.script(op) + self.op = NeighborStatOP(ntypes, rcut, mixed_type) self.auto_batch_size = AutoBatchSize() def iterator(