Skip to content
41 changes: 28 additions & 13 deletions deepmd/dpmodel/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]:
for kk, vv in env_mat.items():
xp = array_api_compat.array_namespace(vv)
stats[kk] = StatItem(
number=vv.size,
number=array_api_compat.size(vv),
sum=float(xp.sum(vv)),
squared_sum=float(xp.sum(xp.square(vv))),
)
Expand Down Expand Up @@ -96,6 +96,18 @@ def iter(
dict[str, StatItem]
The statistics of the environment matrix.
"""
if self.last_dim == 4:
radial_only = False
elif self.last_dim == 1:
radial_only = True
else:
raise ValueError(
"last_dim should be 1 for raial-only or 4 for full descriptor."
)
if len(data) == 0:
# workaround to fix IndexError: list index out of range
yield from ()
return
xp = array_api_compat.array_namespace(data[0]["coord"])
zero_mean = xp.zeros(
(
Expand All @@ -104,6 +116,7 @@ def iter(
self.last_dim,
),
dtype=get_xp_precision(xp, "global"),
device=array_api_compat.device(data[0]["coord"]),
)
one_stddev = xp.ones(
(
Expand All @@ -112,15 +125,8 @@ def iter(
self.last_dim,
),
dtype=get_xp_precision(xp, "global"),
device=array_api_compat.device(data[0]["coord"]),
)
if self.last_dim == 4:
radial_only = False
elif self.last_dim == 1:
radial_only = True
else:
raise ValueError(
"last_dim should be 1 for raial-only or 4 for full descriptor."
)
for system in data:
coord, atype, box, natoms = (
system["coord"],
Expand Down Expand Up @@ -175,16 +181,25 @@ def iter(
type_idx = xp.equal(
xp.reshape(atype, (1, -1)),
xp.reshape(
xp.arange(self.descriptor.get_ntypes(), dtype=xp.int32),
xp.arange(
self.descriptor.get_ntypes(),
dtype=xp.int32,
device=array_api_compat.device(atype),
),
(-1, 1),
),
)
if "pair_exclude_types" in system:
pair_exclude_mask = PairExcludeMask(
self.descriptor.get_ntypes(), system["pair_exclude_types"]
)
pair_exclude_mask.type_mask = xp.asarray(
pair_exclude_mask.type_mask,
device=array_api_compat.device(atype),
)
# shape: (1, nloc, nnei)
exclude_mask = xp.reshape(
PairExcludeMask(
self.descriptor.get_ntypes(), system["pair_exclude_types"]
).build_type_exclude_mask(nlist, extended_atype),
pair_exclude_mask.build_type_exclude_mask(nlist, extended_atype),
(1, coord.shape[0] * coord.shape[1], -1),
)
# shape: (ntypes, nloc, nnei)
Expand Down
11 changes: 10 additions & 1 deletion deepmd/dpmodel/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,16 @@ def build_type_exclude_mask(
nall = atype_ext.shape[1]
# add virtual atom of type ntypes. nf x nall+1
ae = xp.concat(
[atype_ext, self.ntypes * xp.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1
[
atype_ext,
self.ntypes
* xp.ones(
[nf, 1],
dtype=atype_ext.dtype,
device=array_api_compat.device(atype_ext),
),
],
axis=-1,
)
type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1)
# nf x nloc x nnei
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def build_neighbor_list(
nall = coord.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
if coord.size > 0:
if array_api_compat.size(coord) > 0:
xmax = xp.max(coord) + 2.0 * rcut
else:
xmax = 2.0 * rcut
Expand Down
4 changes: 3 additions & 1 deletion deepmd/dpmodel/utils/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def normalize_coord(
"""
xp = array_api_compat.array_namespace(coord, cell)
icoord = phys2inter(coord, cell)
icoord = xp.remainder(icoord, xp.asarray(1.0))
icoord = xp.remainder(
icoord, xp.ones((), dtype=icoord.dtype, device=array_api_compat.device(icoord))
)
return inter2phys(icoord, cell)


Expand Down
236 changes: 7 additions & 229 deletions deepmd/pt/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
@@ -1,234 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Iterator,
)
from typing import (
TYPE_CHECKING,
)

import numpy as np
import torch

from deepmd.common import (
get_hash,
)
from deepmd.pt.model.descriptor.env_mat import (
prod_env_mat,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.nlist import (
extend_input_and_build_neighbor_list,
from deepmd.dpmodel.utils.env_mat_stat import (
EnvMatStat,
EnvMatStatSe,
)
from deepmd.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat
from deepmd.utils.env_mat_stat import (
StatItem,
)

if TYPE_CHECKING:
from deepmd.pt.model.descriptor import (
DescriptorBlock,
)


class EnvMatStat(BaseEnvMatStat):
def compute_stat(self, env_mat: dict[str, torch.Tensor]) -> dict[str, StatItem]:
"""Compute the statistics of the environment matrix for a single system.

Parameters
----------
env_mat : torch.Tensor
The environment matrix.

Returns
-------
dict[str, StatItem]
The statistics of the environment matrix.
"""
stats = {}
for kk, vv in env_mat.items():
stats[kk] = StatItem(
number=vv.numel(),
sum=vv.sum().item(),
squared_sum=torch.square(vv).sum().item(),
)
return stats


class EnvMatStatSe(EnvMatStat):
"""Environmental matrix statistics for the se_a/se_r environmental matrix.

Parameters
----------
descriptor : DescriptorBlock
The descriptor of the model.
"""

def __init__(self, descriptor: "DescriptorBlock") -> None:
super().__init__()
self.descriptor = descriptor
self.last_dim = (
self.descriptor.ndescrpt // self.descriptor.nnei
) # se_r=1, se_a=4

def iter(
self, data: list[dict[str, torch.Tensor | list[tuple[int, int]]]]
) -> Iterator[dict[str, StatItem]]:
"""Get the iterator of the environment matrix.

Parameters
----------
data : list[dict[str, Union[torch.Tensor, list[tuple[int, int]]]]]
The data.

Yields
------
dict[str, StatItem]
The statistics of the environment matrix.
"""
zero_mean = torch.zeros(
self.descriptor.get_ntypes(),
self.descriptor.get_nsel(),
self.last_dim,
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
)
one_stddev = torch.ones(
self.descriptor.get_ntypes(),
self.descriptor.get_nsel(),
self.last_dim,
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
)
if self.last_dim == 4:
radial_only = False
elif self.last_dim == 1:
radial_only = True
else:
raise ValueError(
"last_dim should be 1 for raial-only or 4 for full descriptor."
)
for system in data:
coord, atype, box, natoms = (
system["coord"],
system["atype"],
system["box"],
system["natoms"],
)
(
extended_coord,
extended_atype,
mapping,
nlist,
) = extend_input_and_build_neighbor_list(
coord,
atype,
self.descriptor.get_rcut(),
self.descriptor.get_sel(),
mixed_types=self.descriptor.mixed_types(),
box=box,
)
env_mat, _, _ = prod_env_mat(
extended_coord,
nlist,
atype,
zero_mean,
one_stddev,
self.descriptor.get_rcut(),
self.descriptor.get_rcut_smth(),
radial_only,
protection=self.descriptor.get_env_protection(),
)
# apply excluded_types
exclude_mask = self.descriptor.emask(nlist, extended_atype)
env_mat *= exclude_mask.unsqueeze(-1)
# reshape to nframes * nloc at the atom level,
# so nframes/mixed_type do not matter
env_mat = env_mat.view(
coord.shape[0] * coord.shape[1],
self.descriptor.get_nsel(),
self.last_dim,
)
atype = atype.view(coord.shape[0] * coord.shape[1])
# (1, nloc) eq (ntypes, 1), so broadcast is possible
# shape: (ntypes, nloc)
type_idx = torch.eq(
atype.view(1, -1),
torch.arange(
self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32
).view(-1, 1),
)
if "pair_exclude_types" in system:
# shape: (1, nloc, nnei)
exclude_mask = PairExcludeMask(
self.descriptor.get_ntypes(), system["pair_exclude_types"]
)(nlist, extended_atype).view(1, coord.shape[0] * coord.shape[1], -1)
# shape: (ntypes, nloc, nnei)
type_idx = torch.logical_and(type_idx.unsqueeze(-1), exclude_mask)
for type_i in range(self.descriptor.get_ntypes()):
dd = env_mat[type_idx[type_i]]
dd = dd.reshape([-1, self.last_dim]) # typen_atoms * unmasked_nnei, 4
env_mats = {}
env_mats[f"r_{type_i}"] = dd[:, :1]
if self.last_dim == 4:
env_mats[f"a_{type_i}"] = dd[:, 1:]
yield self.compute_stat(env_mats)

def get_hash(self) -> str:
"""Get the hash of the environment matrix.

Returns
-------
str
The hash of the environment matrix.
"""
dscpt_type = "se_a" if self.last_dim == 4 else "se_r"
return get_hash(
{
"type": dscpt_type,
"ntypes": self.descriptor.get_ntypes(),
"rcut": round(self.descriptor.get_rcut(), 2),
"rcut_smth": round(self.descriptor.rcut_smth, 2),
"nsel": self.descriptor.get_nsel(),
"sel": self.descriptor.get_sel(),
"mixed_types": self.descriptor.mixed_types(),
}
)

def __call__(self) -> tuple[np.ndarray, np.ndarray]:
avgs = self.get_avg()
stds = self.get_std()

all_davg = []
all_dstd = []

for type_i in range(self.descriptor.get_ntypes()):
if self.last_dim == 4:
davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]]
dstdunit = [
[
stds[f"r_{type_i}"],
stds[f"a_{type_i}"],
stds[f"a_{type_i}"],
stds[f"a_{type_i}"],
]
]
elif self.last_dim == 1:
davgunit = [[avgs[f"r_{type_i}"]]]
dstdunit = [
[
stds[f"r_{type_i}"],
]
]
davg = np.tile(davgunit, [self.descriptor.get_nsel(), 1])
dstd = np.tile(dstdunit, [self.descriptor.get_nsel(), 1])
all_davg.append(davg)
all_dstd.append(dstd)

mean = np.stack(all_davg)
stddev = np.stack(all_dstd)
return mean, stddev
__all__ = [
"EnvMatStat",
"EnvMatStatSe",
]
2 changes: 2 additions & 0 deletions deepmd/pt/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,5 @@ def forward(
type_ij = type_ij.view(nf, nloc * nnei)
mask = self.type_mask[type_ij].view(nf, nloc, nnei).to(atype_ext.device)
return mask

build_type_exclude_mask = forward