Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/dpmodel/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def call(
)
assert list(diff.shape) == [nframes, nloc, nall, 3]
# remove the diagonal elements
mask = xp.eye(nloc, nall, dtype=xp.bool)
mask = xp.eye(nloc, nall, dtype=xp.bool, device=array_api_compat.device(diff))
mask = xp.tile(mask[None, :, :, None], (nframes, 1, 1, 3))
diff = xp.where(mask, xp.full_like(diff, xp.inf), diff)
rr2 = xp.sum(xp.square(diff), axis=-1)
Expand Down
72 changes: 61 additions & 11 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def build_neighbor_list(
assert list(diff.shape) == [batch_size, nloc, nall, 3]
rr = xp.linalg.vector_norm(diff, axis=-1)
# if central atom has two zero distances, sorting sometimes can not exclude itself
rr -= xp.eye(nloc, nall, dtype=diff.dtype)[xp.newaxis, :, :]
rr -= xp.eye(nloc, nall, dtype=diff.dtype, device=array_api_compat.device(diff))[
xp.newaxis, :, :
]
nlist = xp.argsort(rr, axis=-1)
rr = xp.sort(rr, axis=-1)
rr = rr[:, :, 1:]
Expand All @@ -128,11 +130,26 @@ def build_neighbor_list(
nlist = nlist[:, :, :nsel]
else:
rr = xp.concatenate(
[rr, xp.ones([batch_size, nloc, nsel - nnei], dtype=rr.dtype) + rcut],
[
rr,
xp.ones(
[batch_size, nloc, nsel - nnei],
dtype=rr.dtype,
device=array_api_compat.device(rr),
)
+ rcut,
],
axis=-1,
)
nlist = xp.concatenate(
[nlist, xp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)],
[
nlist,
xp.ones(
[batch_size, nloc, nsel - nnei],
dtype=nlist.dtype,
device=array_api_compat.device(nlist),
),
],
axis=-1,
)
assert list(nlist.shape) == [batch_size, nloc, nsel]
Expand Down Expand Up @@ -218,7 +235,11 @@ def build_multiple_neighbor_list(
return {}
nb, nloc, nsel = nlist.shape
if nsel < nsels[-1]:
pad = -1 * xp.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype)
pad = -1 * xp.ones(
(nb, nloc, nsels[-1] - nsel),
dtype=nlist.dtype,
device=array_api_compat.device(nlist),
)
nlist = xp.concat([nlist, pad], axis=-1)
nsel = nsels[-1]
coord1 = xp.reshape(coord, (nb, -1, 3))
Expand Down Expand Up @@ -276,7 +297,12 @@ def extend_coord_with_ghosts(
xp = array_api_compat.array_namespace(coord, atype)
nf, nloc = atype.shape
# int64 for index
aidx = xp.tile(xp.arange(nloc, dtype=xp.int64)[xp.newaxis, :], (nf, 1))
aidx = xp.tile(
xp.arange(nloc, dtype=xp.int64, device=array_api_compat.device(atype))[
xp.newaxis, :
],
(nf, 1),
)
if cell is None:
nall = nloc
extend_coord = coord
Expand All @@ -288,17 +314,41 @@ def extend_coord_with_ghosts(
to_face = to_face_distance(cell)
nbuff = xp.astype(xp.ceil(rcut / to_face), xp.int64)
nbuff = xp.max(nbuff, axis=0)
xi = xp.arange(-int(nbuff[0]), int(nbuff[0]) + 1, 1, dtype=xp.int64)
yi = xp.arange(-int(nbuff[1]), int(nbuff[1]) + 1, 1, dtype=xp.int64)
zi = xp.arange(-int(nbuff[2]), int(nbuff[2]) + 1, 1, dtype=xp.int64)
xyz = xp.linalg.outer(xi, xp.asarray([1, 0, 0]))[:, xp.newaxis, xp.newaxis, :]
xi = xp.arange(
-int(nbuff[0]),
int(nbuff[0]) + 1,
1,
dtype=xp.int64,
device=array_api_compat.device(coord),
)
yi = xp.arange(
-int(nbuff[1]),
int(nbuff[1]) + 1,
1,
dtype=xp.int64,
device=array_api_compat.device(coord),
)
zi = xp.arange(
-int(nbuff[2]),
int(nbuff[2]) + 1,
1,
dtype=xp.int64,
device=array_api_compat.device(coord),
)
xyz = xp.linalg.outer(
xi, xp.asarray([1, 0, 0], device=array_api_compat.device(xi))
)[:, xp.newaxis, xp.newaxis, :]
xyz = (
xyz
+ xp.linalg.outer(yi, xp.asarray([0, 1, 0]))[xp.newaxis, :, xp.newaxis, :]
+ xp.linalg.outer(
yi, xp.asarray([0, 1, 0], device=array_api_compat.device(yi))
)[xp.newaxis, :, xp.newaxis, :]
)
xyz = (
xyz
+ xp.linalg.outer(zi, xp.asarray([0, 0, 1]))[xp.newaxis, xp.newaxis, :, :]
+ xp.linalg.outer(
zi, xp.asarray([0, 0, 1], device=array_api_compat.device(zi))
)[xp.newaxis, xp.newaxis, :, :]
)
xyz = xp.reshape(xyz, (-1, 3))
xyz = xp.astype(xyz, coord.dtype)
Expand Down
99 changes: 4 additions & 95 deletions deepmd/pt/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,111 +6,21 @@
import numpy as np
import torch

from deepmd.dpmodel.utils.neighbor_stat import (
NeighborStatOP,
)
from deepmd.pt.utils.auto_batch_size import (
AutoBatchSize,
)
from deepmd.pt.utils.env import (
DEVICE,
)
from deepmd.pt.utils.nlist import (
extend_coord_with_ghosts,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat


class NeighborStatOP(torch.nn.Module):
"""Class for getting neighbor statistics data information.

Parameters
----------
ntypes
The num of atom types
rcut
The cut-off radius
mixed_types : bool, optional
If True, treat neighbors of all types as a single type.
"""

def __init__(
self,
ntypes: int,
rcut: float,
mixed_types: bool,
) -> None:
super().__init__()
self.rcut = float(rcut)
self.ntypes = ntypes
self.mixed_types = mixed_types

def forward(
self,
coord: torch.Tensor,
atype: torch.Tensor,
cell: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculate the neareest neighbor distance between atoms, maximum nbor size of
atoms and the output data range of the environment matrix.

Parameters
----------
coord
The coordinates of atoms.
atype
The atom types.
cell
The cell.

Returns
-------
torch.Tensor
The minimal squared distance between two atoms, in the shape of (nframes,)
torch.Tensor
The maximal number of neighbors
"""
nframes = coord.shape[0]
coord = coord.view(nframes, -1, 3)
nloc = coord.shape[1]
coord = coord.view(nframes, nloc * 3)
extend_coord, extend_atype, _ = extend_coord_with_ghosts(
coord, atype, cell, self.rcut
)

coord1 = extend_coord.reshape(nframes, -1)
nall = coord1.shape[1] // 3
coord0 = coord1[:, : nloc * 3]
diff = (
coord1.reshape([nframes, -1, 3])[:, None, :, :]
- coord0.reshape([nframes, -1, 3])[:, :, None, :]
)
assert list(diff.shape) == [nframes, nloc, nall, 3]
# remove the diagonal elements
mask = torch.eye(nloc, nall, dtype=torch.bool, device=diff.device)
diff[:, mask] = torch.inf
rr2 = torch.sum(torch.square(diff), dim=-1)
min_rr2, _ = torch.min(rr2, dim=-1)
# count the number of neighbors
if not self.mixed_types:
mask = rr2 < self.rcut**2
nnei = torch.zeros(
(nframes, nloc, self.ntypes), dtype=torch.int32, device=mask.device
)
for ii in range(self.ntypes):
nnei[:, :, ii] = torch.sum(
mask & extend_atype.eq(ii)[:, None, :], dim=-1
)
else:
mask = rr2 < self.rcut**2
# virtual types (<0) are not counted
nnei = torch.sum(mask & extend_atype.ge(0)[:, None, :], dim=-1).view(
nframes, nloc, 1
)
max_nnei, _ = torch.max(nnei, dim=1)
return min_rr2, max_nnei


class NeighborStat(BaseNeighborStat):
"""Neighbor statistics using pure NumPy.

Expand All @@ -131,8 +41,7 @@ def __init__(
mixed_type: bool = False,
) -> None:
super().__init__(ntypes, rcut, mixed_type)
op = NeighborStatOP(ntypes, rcut, mixed_type)
self.op = torch.jit.script(op)
self.op = NeighborStatOP(ntypes, rcut, mixed_type)
self.auto_batch_size = AutoBatchSize()

def iterator(
Expand Down