diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index b8befa0087..37a69ea1b1 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -58,7 +58,7 @@ def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]: for kk, vv in env_mat.items(): xp = array_api_compat.array_namespace(vv) stats[kk] = StatItem( - number=vv.size, + number=array_api_compat.size(vv), sum=float(xp.sum(vv)), squared_sum=float(xp.sum(xp.square(vv))), ) @@ -96,6 +96,18 @@ def iter( dict[str, StatItem] The statistics of the environment matrix. """ + if self.last_dim == 4: + radial_only = False + elif self.last_dim == 1: + radial_only = True + else: + raise ValueError( + "last_dim should be 1 for raial-only or 4 for full descriptor." + ) + if len(data) == 0: + # workaround to fix IndexError: list index out of range + yield from () + return xp = array_api_compat.array_namespace(data[0]["coord"]) zero_mean = xp.zeros( ( @@ -104,6 +116,7 @@ def iter( self.last_dim, ), dtype=get_xp_precision(xp, "global"), + device=array_api_compat.device(data[0]["coord"]), ) one_stddev = xp.ones( ( @@ -112,15 +125,8 @@ def iter( self.last_dim, ), dtype=get_xp_precision(xp, "global"), + device=array_api_compat.device(data[0]["coord"]), ) - if self.last_dim == 4: - radial_only = False - elif self.last_dim == 1: - radial_only = True - else: - raise ValueError( - "last_dim should be 1 for raial-only or 4 for full descriptor." - ) for system in data: coord, atype, box, natoms = ( system["coord"], @@ -175,16 +181,25 @@ def iter( type_idx = xp.equal( xp.reshape(atype, (1, -1)), xp.reshape( - xp.arange(self.descriptor.get_ntypes(), dtype=xp.int32), + xp.arange( + self.descriptor.get_ntypes(), + dtype=xp.int32, + device=array_api_compat.device(atype), + ), (-1, 1), ), ) if "pair_exclude_types" in system: + pair_exclude_mask = PairExcludeMask( + self.descriptor.get_ntypes(), system["pair_exclude_types"] + ) + pair_exclude_mask.type_mask = xp.asarray( + pair_exclude_mask.type_mask, + device=array_api_compat.device(atype), + ) # shape: (1, nloc, nnei) exclude_mask = xp.reshape( - PairExcludeMask( - self.descriptor.get_ntypes(), system["pair_exclude_types"] - ).build_type_exclude_mask(nlist, extended_atype), + pair_exclude_mask.build_type_exclude_mask(nlist, extended_atype), (1, coord.shape[0] * coord.shape[1], -1), ) # shape: (ntypes, nloc, nnei) diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index 9d8f0c8572..4a65d33775 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -120,7 +120,16 @@ def build_type_exclude_mask( nall = atype_ext.shape[1] # add virtual atom of type ntypes. nf x nall+1 ae = xp.concat( - [atype_ext, self.ntypes * xp.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1 + [ + atype_ext, + self.ntypes + * xp.ones( + [nf, 1], + dtype=atype_ext.dtype, + device=array_api_compat.device(atype_ext), + ), + ], + axis=-1, ) type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1) # nf x nloc x nnei diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index eb5320bad1..55fb1a73f9 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -96,7 +96,7 @@ def build_neighbor_list( nall = coord.shape[1] // 3 # fill virtual atoms with large coords so they are not neighbors of any # real atom. - if coord.size > 0: + if array_api_compat.size(coord) > 0: xmax = xp.max(coord) + 2.0 * rcut else: xmax = 2.0 * rcut diff --git a/deepmd/dpmodel/utils/region.py b/deepmd/dpmodel/utils/region.py index 6d8dfebf88..61786c33a0 100644 --- a/deepmd/dpmodel/utils/region.py +++ b/deepmd/dpmodel/utils/region.py @@ -74,7 +74,9 @@ def normalize_coord( """ xp = array_api_compat.array_namespace(coord, cell) icoord = phys2inter(coord, cell) - icoord = xp.remainder(icoord, xp.asarray(1.0)) + icoord = xp.remainder( + icoord, xp.ones((), dtype=icoord.dtype, device=array_api_compat.device(icoord)) + ) return inter2phys(icoord, cell) diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 01822c7f3f..cfb5da0dea 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -1,234 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from collections.abc import ( - Iterator, -) -from typing import ( - TYPE_CHECKING, -) -import numpy as np -import torch -from deepmd.common import ( - get_hash, -) -from deepmd.pt.model.descriptor.env_mat import ( - prod_env_mat, -) -from deepmd.pt.utils import ( - env, -) -from deepmd.pt.utils.exclude_mask import ( - PairExcludeMask, -) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, +from deepmd.dpmodel.utils.env_mat_stat import ( + EnvMatStat, + EnvMatStatSe, ) -from deepmd.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat -from deepmd.utils.env_mat_stat import ( - StatItem, -) - -if TYPE_CHECKING: - from deepmd.pt.model.descriptor import ( - DescriptorBlock, - ) - - -class EnvMatStat(BaseEnvMatStat): - def compute_stat(self, env_mat: dict[str, torch.Tensor]) -> dict[str, StatItem]: - """Compute the statistics of the environment matrix for a single system. - - Parameters - ---------- - env_mat : torch.Tensor - The environment matrix. - - Returns - ------- - dict[str, StatItem] - The statistics of the environment matrix. - """ - stats = {} - for kk, vv in env_mat.items(): - stats[kk] = StatItem( - number=vv.numel(), - sum=vv.sum().item(), - squared_sum=torch.square(vv).sum().item(), - ) - return stats - - -class EnvMatStatSe(EnvMatStat): - """Environmental matrix statistics for the se_a/se_r environmental matrix. - - Parameters - ---------- - descriptor : DescriptorBlock - The descriptor of the model. - """ - - def __init__(self, descriptor: "DescriptorBlock") -> None: - super().__init__() - self.descriptor = descriptor - self.last_dim = ( - self.descriptor.ndescrpt // self.descriptor.nnei - ) # se_r=1, se_a=4 - - def iter( - self, data: list[dict[str, torch.Tensor | list[tuple[int, int]]]] - ) -> Iterator[dict[str, StatItem]]: - """Get the iterator of the environment matrix. - - Parameters - ---------- - data : list[dict[str, Union[torch.Tensor, list[tuple[int, int]]]]] - The data. - - Yields - ------ - dict[str, StatItem] - The statistics of the environment matrix. - """ - zero_mean = torch.zeros( - self.descriptor.get_ntypes(), - self.descriptor.get_nsel(), - self.last_dim, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - device=env.DEVICE, - ) - one_stddev = torch.ones( - self.descriptor.get_ntypes(), - self.descriptor.get_nsel(), - self.last_dim, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - device=env.DEVICE, - ) - if self.last_dim == 4: - radial_only = False - elif self.last_dim == 1: - radial_only = True - else: - raise ValueError( - "last_dim should be 1 for raial-only or 4 for full descriptor." - ) - for system in data: - coord, atype, box, natoms = ( - system["coord"], - system["atype"], - system["box"], - system["natoms"], - ) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord, - atype, - self.descriptor.get_rcut(), - self.descriptor.get_sel(), - mixed_types=self.descriptor.mixed_types(), - box=box, - ) - env_mat, _, _ = prod_env_mat( - extended_coord, - nlist, - atype, - zero_mean, - one_stddev, - self.descriptor.get_rcut(), - self.descriptor.get_rcut_smth(), - radial_only, - protection=self.descriptor.get_env_protection(), - ) - # apply excluded_types - exclude_mask = self.descriptor.emask(nlist, extended_atype) - env_mat *= exclude_mask.unsqueeze(-1) - # reshape to nframes * nloc at the atom level, - # so nframes/mixed_type do not matter - env_mat = env_mat.view( - coord.shape[0] * coord.shape[1], - self.descriptor.get_nsel(), - self.last_dim, - ) - atype = atype.view(coord.shape[0] * coord.shape[1]) - # (1, nloc) eq (ntypes, 1), so broadcast is possible - # shape: (ntypes, nloc) - type_idx = torch.eq( - atype.view(1, -1), - torch.arange( - self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32 - ).view(-1, 1), - ) - if "pair_exclude_types" in system: - # shape: (1, nloc, nnei) - exclude_mask = PairExcludeMask( - self.descriptor.get_ntypes(), system["pair_exclude_types"] - )(nlist, extended_atype).view(1, coord.shape[0] * coord.shape[1], -1) - # shape: (ntypes, nloc, nnei) - type_idx = torch.logical_and(type_idx.unsqueeze(-1), exclude_mask) - for type_i in range(self.descriptor.get_ntypes()): - dd = env_mat[type_idx[type_i]] - dd = dd.reshape([-1, self.last_dim]) # typen_atoms * unmasked_nnei, 4 - env_mats = {} - env_mats[f"r_{type_i}"] = dd[:, :1] - if self.last_dim == 4: - env_mats[f"a_{type_i}"] = dd[:, 1:] - yield self.compute_stat(env_mats) - - def get_hash(self) -> str: - """Get the hash of the environment matrix. - - Returns - ------- - str - The hash of the environment matrix. - """ - dscpt_type = "se_a" if self.last_dim == 4 else "se_r" - return get_hash( - { - "type": dscpt_type, - "ntypes": self.descriptor.get_ntypes(), - "rcut": round(self.descriptor.get_rcut(), 2), - "rcut_smth": round(self.descriptor.rcut_smth, 2), - "nsel": self.descriptor.get_nsel(), - "sel": self.descriptor.get_sel(), - "mixed_types": self.descriptor.mixed_types(), - } - ) - - def __call__(self) -> tuple[np.ndarray, np.ndarray]: - avgs = self.get_avg() - stds = self.get_std() - - all_davg = [] - all_dstd = [] - - for type_i in range(self.descriptor.get_ntypes()): - if self.last_dim == 4: - davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] - dstdunit = [ - [ - stds[f"r_{type_i}"], - stds[f"a_{type_i}"], - stds[f"a_{type_i}"], - stds[f"a_{type_i}"], - ] - ] - elif self.last_dim == 1: - davgunit = [[avgs[f"r_{type_i}"]]] - dstdunit = [ - [ - stds[f"r_{type_i}"], - ] - ] - davg = np.tile(davgunit, [self.descriptor.get_nsel(), 1]) - dstd = np.tile(dstdunit, [self.descriptor.get_nsel(), 1]) - all_davg.append(davg) - all_dstd.append(dstd) - mean = np.stack(all_davg) - stddev = np.stack(all_dstd) - return mean, stddev +__all__ = [ + "EnvMatStat", + "EnvMatStatSe", +] diff --git a/deepmd/pt/utils/exclude_mask.py b/deepmd/pt/utils/exclude_mask.py index cf39220f1b..451b32252c 100644 --- a/deepmd/pt/utils/exclude_mask.py +++ b/deepmd/pt/utils/exclude_mask.py @@ -147,3 +147,5 @@ def forward( type_ij = type_ij.view(nf, nloc * nnei) mask = self.type_mask[type_ij].view(nf, nloc, nnei).to(atype_ext.device) return mask + + build_type_exclude_mask = forward