From 0e5b449bafe7195f6bf45b07695aa7638fcd00a3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 8 Jan 2026 19:23:06 +0800 Subject: [PATCH 01/10] refactor(pt): reuse env mat stat --- deepmd/pt/utils/env_mat_stat.py | 235 +------------------------------- 1 file changed, 7 insertions(+), 228 deletions(-) diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 01822c7f3f..ee2bf043d2 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -1,234 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from collections.abc import ( - Iterator, -) -from typing import ( - TYPE_CHECKING, -) -import numpy as np -import torch -from deepmd.common import ( - get_hash, -) -from deepmd.pt.model.descriptor.env_mat import ( - prod_env_mat, -) -from deepmd.pt.utils import ( - env, -) -from deepmd.pt.utils.exclude_mask import ( - PairExcludeMask, -) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, +from deepmd.dpmodel.utils.env_mat_stat import ( + EnvMatStat, + EnvMatStatSe, ) -from deepmd.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat -from deepmd.utils.env_mat_stat import ( - StatItem, -) - -if TYPE_CHECKING: - from deepmd.pt.model.descriptor import ( - DescriptorBlock, - ) - - -class EnvMatStat(BaseEnvMatStat): - def compute_stat(self, env_mat: dict[str, torch.Tensor]) -> dict[str, StatItem]: - """Compute the statistics of the environment matrix for a single system. - - Parameters - ---------- - env_mat : torch.Tensor - The environment matrix. - - Returns - ------- - dict[str, StatItem] - The statistics of the environment matrix. - """ - stats = {} - for kk, vv in env_mat.items(): - stats[kk] = StatItem( - number=vv.numel(), - sum=vv.sum().item(), - squared_sum=torch.square(vv).sum().item(), - ) - return stats - - -class EnvMatStatSe(EnvMatStat): - """Environmental matrix statistics for the se_a/se_r environmental matrix. - - Parameters - ---------- - descriptor : DescriptorBlock - The descriptor of the model. - """ - - def __init__(self, descriptor: "DescriptorBlock") -> None: - super().__init__() - self.descriptor = descriptor - self.last_dim = ( - self.descriptor.ndescrpt // self.descriptor.nnei - ) # se_r=1, se_a=4 - - def iter( - self, data: list[dict[str, torch.Tensor | list[tuple[int, int]]]] - ) -> Iterator[dict[str, StatItem]]: - """Get the iterator of the environment matrix. - - Parameters - ---------- - data : list[dict[str, Union[torch.Tensor, list[tuple[int, int]]]]] - The data. - - Yields - ------ - dict[str, StatItem] - The statistics of the environment matrix. - """ - zero_mean = torch.zeros( - self.descriptor.get_ntypes(), - self.descriptor.get_nsel(), - self.last_dim, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - device=env.DEVICE, - ) - one_stddev = torch.ones( - self.descriptor.get_ntypes(), - self.descriptor.get_nsel(), - self.last_dim, - dtype=env.GLOBAL_PT_FLOAT_PRECISION, - device=env.DEVICE, - ) - if self.last_dim == 4: - radial_only = False - elif self.last_dim == 1: - radial_only = True - else: - raise ValueError( - "last_dim should be 1 for raial-only or 4 for full descriptor." - ) - for system in data: - coord, atype, box, natoms = ( - system["coord"], - system["atype"], - system["box"], - system["natoms"], - ) - ( - extended_coord, - extended_atype, - mapping, - nlist, - ) = extend_input_and_build_neighbor_list( - coord, - atype, - self.descriptor.get_rcut(), - self.descriptor.get_sel(), - mixed_types=self.descriptor.mixed_types(), - box=box, - ) - env_mat, _, _ = prod_env_mat( - extended_coord, - nlist, - atype, - zero_mean, - one_stddev, - self.descriptor.get_rcut(), - self.descriptor.get_rcut_smth(), - radial_only, - protection=self.descriptor.get_env_protection(), - ) - # apply excluded_types - exclude_mask = self.descriptor.emask(nlist, extended_atype) - env_mat *= exclude_mask.unsqueeze(-1) - # reshape to nframes * nloc at the atom level, - # so nframes/mixed_type do not matter - env_mat = env_mat.view( - coord.shape[0] * coord.shape[1], - self.descriptor.get_nsel(), - self.last_dim, - ) - atype = atype.view(coord.shape[0] * coord.shape[1]) - # (1, nloc) eq (ntypes, 1), so broadcast is possible - # shape: (ntypes, nloc) - type_idx = torch.eq( - atype.view(1, -1), - torch.arange( - self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32 - ).view(-1, 1), - ) - if "pair_exclude_types" in system: - # shape: (1, nloc, nnei) - exclude_mask = PairExcludeMask( - self.descriptor.get_ntypes(), system["pair_exclude_types"] - )(nlist, extended_atype).view(1, coord.shape[0] * coord.shape[1], -1) - # shape: (ntypes, nloc, nnei) - type_idx = torch.logical_and(type_idx.unsqueeze(-1), exclude_mask) - for type_i in range(self.descriptor.get_ntypes()): - dd = env_mat[type_idx[type_i]] - dd = dd.reshape([-1, self.last_dim]) # typen_atoms * unmasked_nnei, 4 - env_mats = {} - env_mats[f"r_{type_i}"] = dd[:, :1] - if self.last_dim == 4: - env_mats[f"a_{type_i}"] = dd[:, 1:] - yield self.compute_stat(env_mats) - - def get_hash(self) -> str: - """Get the hash of the environment matrix. - - Returns - ------- - str - The hash of the environment matrix. - """ - dscpt_type = "se_a" if self.last_dim == 4 else "se_r" - return get_hash( - { - "type": dscpt_type, - "ntypes": self.descriptor.get_ntypes(), - "rcut": round(self.descriptor.get_rcut(), 2), - "rcut_smth": round(self.descriptor.rcut_smth, 2), - "nsel": self.descriptor.get_nsel(), - "sel": self.descriptor.get_sel(), - "mixed_types": self.descriptor.mixed_types(), - } - ) - - def __call__(self) -> tuple[np.ndarray, np.ndarray]: - avgs = self.get_avg() - stds = self.get_std() - - all_davg = [] - all_dstd = [] - for type_i in range(self.descriptor.get_ntypes()): - if self.last_dim == 4: - davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] - dstdunit = [ - [ - stds[f"r_{type_i}"], - stds[f"a_{type_i}"], - stds[f"a_{type_i}"], - stds[f"a_{type_i}"], - ] - ] - elif self.last_dim == 1: - davgunit = [[avgs[f"r_{type_i}"]]] - dstdunit = [ - [ - stds[f"r_{type_i}"], - ] - ] - davg = np.tile(davgunit, [self.descriptor.get_nsel(), 1]) - dstd = np.tile(dstdunit, [self.descriptor.get_nsel(), 1]) - all_davg.append(davg) - all_dstd.append(dstd) - mean = np.stack(all_davg) - stddev = np.stack(all_dstd) - return mean, stddev +__all__ = [ + "EnvMatStat", + "EnvMatStatSe", +] From c8952821facd576eb8e6f57bd98a899f07e6b4ff Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 8 Jan 2026 19:26:43 +0800 Subject: [PATCH 02/10] sync from #5137 --- deepmd/dpmodel/utils/nlist.py | 72 +++++++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 11 deletions(-) diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index a43cf46403..eb5320bad1 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -117,7 +117,9 @@ def build_neighbor_list( assert list(diff.shape) == [batch_size, nloc, nall, 3] rr = xp.linalg.vector_norm(diff, axis=-1) # if central atom has two zero distances, sorting sometimes can not exclude itself - rr -= xp.eye(nloc, nall, dtype=diff.dtype)[xp.newaxis, :, :] + rr -= xp.eye(nloc, nall, dtype=diff.dtype, device=array_api_compat.device(diff))[ + xp.newaxis, :, : + ] nlist = xp.argsort(rr, axis=-1) rr = xp.sort(rr, axis=-1) rr = rr[:, :, 1:] @@ -128,11 +130,26 @@ def build_neighbor_list( nlist = nlist[:, :, :nsel] else: rr = xp.concatenate( - [rr, xp.ones([batch_size, nloc, nsel - nnei], dtype=rr.dtype) + rcut], + [ + rr, + xp.ones( + [batch_size, nloc, nsel - nnei], + dtype=rr.dtype, + device=array_api_compat.device(rr), + ) + + rcut, + ], axis=-1, ) nlist = xp.concatenate( - [nlist, xp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)], + [ + nlist, + xp.ones( + [batch_size, nloc, nsel - nnei], + dtype=nlist.dtype, + device=array_api_compat.device(nlist), + ), + ], axis=-1, ) assert list(nlist.shape) == [batch_size, nloc, nsel] @@ -218,7 +235,11 @@ def build_multiple_neighbor_list( return {} nb, nloc, nsel = nlist.shape if nsel < nsels[-1]: - pad = -1 * xp.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype) + pad = -1 * xp.ones( + (nb, nloc, nsels[-1] - nsel), + dtype=nlist.dtype, + device=array_api_compat.device(nlist), + ) nlist = xp.concat([nlist, pad], axis=-1) nsel = nsels[-1] coord1 = xp.reshape(coord, (nb, -1, 3)) @@ -276,7 +297,12 @@ def extend_coord_with_ghosts( xp = array_api_compat.array_namespace(coord, atype) nf, nloc = atype.shape # int64 for index - aidx = xp.tile(xp.arange(nloc, dtype=xp.int64)[xp.newaxis, :], (nf, 1)) + aidx = xp.tile( + xp.arange(nloc, dtype=xp.int64, device=array_api_compat.device(atype))[ + xp.newaxis, : + ], + (nf, 1), + ) if cell is None: nall = nloc extend_coord = coord @@ -288,17 +314,41 @@ def extend_coord_with_ghosts( to_face = to_face_distance(cell) nbuff = xp.astype(xp.ceil(rcut / to_face), xp.int64) nbuff = xp.max(nbuff, axis=0) - xi = xp.arange(-int(nbuff[0]), int(nbuff[0]) + 1, 1, dtype=xp.int64) - yi = xp.arange(-int(nbuff[1]), int(nbuff[1]) + 1, 1, dtype=xp.int64) - zi = xp.arange(-int(nbuff[2]), int(nbuff[2]) + 1, 1, dtype=xp.int64) - xyz = xp.linalg.outer(xi, xp.asarray([1, 0, 0]))[:, xp.newaxis, xp.newaxis, :] + xi = xp.arange( + -int(nbuff[0]), + int(nbuff[0]) + 1, + 1, + dtype=xp.int64, + device=array_api_compat.device(coord), + ) + yi = xp.arange( + -int(nbuff[1]), + int(nbuff[1]) + 1, + 1, + dtype=xp.int64, + device=array_api_compat.device(coord), + ) + zi = xp.arange( + -int(nbuff[2]), + int(nbuff[2]) + 1, + 1, + dtype=xp.int64, + device=array_api_compat.device(coord), + ) + xyz = xp.linalg.outer( + xi, xp.asarray([1, 0, 0], device=array_api_compat.device(xi)) + )[:, xp.newaxis, xp.newaxis, :] xyz = ( xyz - + xp.linalg.outer(yi, xp.asarray([0, 1, 0]))[xp.newaxis, :, xp.newaxis, :] + + xp.linalg.outer( + yi, xp.asarray([0, 1, 0], device=array_api_compat.device(yi)) + )[xp.newaxis, :, xp.newaxis, :] ) xyz = ( xyz - + xp.linalg.outer(zi, xp.asarray([0, 0, 1]))[xp.newaxis, xp.newaxis, :, :] + + xp.linalg.outer( + zi, xp.asarray([0, 0, 1], device=array_api_compat.device(zi)) + )[xp.newaxis, xp.newaxis, :, :] ) xyz = xp.reshape(xyz, (-1, 3)) xyz = xp.astype(xyz, coord.dtype) From 0521725cd84649a2c0dc2b419eb63770446c61b0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 8 Jan 2026 19:37:02 +0800 Subject: [PATCH 03/10] fix other errors --- deepmd/dpmodel/utils/env_mat_stat.py | 10 ++++++++-- deepmd/dpmodel/utils/nlist.py | 2 +- deepmd/pt/utils/exclude_mask.py | 2 ++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index b8befa0087..ec7fed10b6 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -58,7 +58,7 @@ def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]: for kk, vv in env_mat.items(): xp = array_api_compat.array_namespace(vv) stats[kk] = StatItem( - number=vv.size, + number=array_api_compat.size(vv), sum=float(xp.sum(vv)), squared_sum=float(xp.sum(xp.square(vv))), ) @@ -104,6 +104,7 @@ def iter( self.last_dim, ), dtype=get_xp_precision(xp, "global"), + device=array_api_compat.device(data[0]["coord"]), ) one_stddev = xp.ones( ( @@ -112,6 +113,7 @@ def iter( self.last_dim, ), dtype=get_xp_precision(xp, "global"), + device=array_api_compat.device(data[0]["coord"]), ) if self.last_dim == 4: radial_only = False @@ -175,7 +177,11 @@ def iter( type_idx = xp.equal( xp.reshape(atype, (1, -1)), xp.reshape( - xp.arange(self.descriptor.get_ntypes(), dtype=xp.int32), + xp.arange( + self.descriptor.get_ntypes(), + dtype=xp.int32, + device=array_api_compat.device(atype), + ), (-1, 1), ), ) diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index eb5320bad1..55fb1a73f9 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -96,7 +96,7 @@ def build_neighbor_list( nall = coord.shape[1] // 3 # fill virtual atoms with large coords so they are not neighbors of any # real atom. - if coord.size > 0: + if array_api_compat.size(coord) > 0: xmax = xp.max(coord) + 2.0 * rcut else: xmax = 2.0 * rcut diff --git a/deepmd/pt/utils/exclude_mask.py b/deepmd/pt/utils/exclude_mask.py index cf39220f1b..451b32252c 100644 --- a/deepmd/pt/utils/exclude_mask.py +++ b/deepmd/pt/utils/exclude_mask.py @@ -147,3 +147,5 @@ def forward( type_ij = type_ij.view(nf, nloc * nnei) mask = self.type_mask[type_ij].view(nf, nloc, nnei).to(atype_ext.device) return mask + + build_type_exclude_mask = forward From b652f03a04e90bb4ecc60c02cab7212828a51c76 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Jan 2026 11:39:41 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/env_mat_stat.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index ee2bf043d2..cfb5da0dea 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -6,7 +6,6 @@ EnvMatStatSe, ) - __all__ = [ "EnvMatStat", "EnvMatStatSe", From ad01332834bd66716840c941824290f7b4993881 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 9 Jan 2026 00:19:16 +0800 Subject: [PATCH 05/10] fix: improve remainder calculation in normalize_coord function --- deepmd/dpmodel/utils/region.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/utils/region.py b/deepmd/dpmodel/utils/region.py index 6d8dfebf88..61786c33a0 100644 --- a/deepmd/dpmodel/utils/region.py +++ b/deepmd/dpmodel/utils/region.py @@ -74,7 +74,9 @@ def normalize_coord( """ xp = array_api_compat.array_namespace(coord, cell) icoord = phys2inter(coord, cell) - icoord = xp.remainder(icoord, xp.asarray(1.0)) + icoord = xp.remainder( + icoord, xp.ones((), dtype=icoord.dtype, device=array_api_compat.device(icoord)) + ) return inter2phys(icoord, cell) From e72375e4f96843d851bca290315e94aae5069675 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 10 Jan 2026 14:44:39 +0800 Subject: [PATCH 06/10] fix xp.ones --- deepmd/dpmodel/utils/exclude_mask.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index 9d8f0c8572..4a65d33775 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -120,7 +120,16 @@ def build_type_exclude_mask( nall = atype_ext.shape[1] # add virtual atom of type ntypes. nf x nall+1 ae = xp.concat( - [atype_ext, self.ntypes * xp.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1 + [ + atype_ext, + self.ntypes + * xp.ones( + [nf, 1], + dtype=atype_ext.dtype, + device=array_api_compat.device(atype_ext), + ), + ], + axis=-1, ) type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1) # nf x nloc x nnei From 4058f9b767c8360bda69d6cb27dfc20db27268cd Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 10 Jan 2026 14:54:10 +0800 Subject: [PATCH 07/10] fix IndexError: list index out of range --- deepmd/dpmodel/utils/env_mat_stat.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index ec7fed10b6..1f793fdb69 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -96,6 +96,9 @@ def iter( dict[str, StatItem] The statistics of the environment matrix. """ + if len(data) == 0: + # workaround to fix IndexError: list index out of range + return xp = array_api_compat.array_namespace(data[0]["coord"]) zero_mean = xp.zeros( ( From a46cffdb0e587d6070e1f936139b5cded0947f93 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 10 Jan 2026 15:52:28 +0800 Subject: [PATCH 08/10] fix PairExcludeMask --- deepmd/dpmodel/utils/env_mat_stat.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index 1f793fdb69..f23a24076d 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -98,7 +98,7 @@ def iter( """ if len(data) == 0: # workaround to fix IndexError: list index out of range - return + yield from () xp = array_api_compat.array_namespace(data[0]["coord"]) zero_mean = xp.zeros( ( @@ -189,11 +189,16 @@ def iter( ), ) if "pair_exclude_types" in system: + pair_exclude_mask = PairExcludeMask( + self.descriptor.get_ntypes(), system["pair_exclude_types"] + ) + pair_exclude_mask.type_mask = xp.asarray( + pair_exclude_mask.type_mask, + device=array_api_compat.device(atype), + ) # shape: (1, nloc, nnei) exclude_mask = xp.reshape( - PairExcludeMask( - self.descriptor.get_ntypes(), system["pair_exclude_types"] - ).build_type_exclude_mask(nlist, extended_atype), + pair_exclude_mask.build_type_exclude_mask(nlist, extended_atype), (1, coord.shape[0] * coord.shape[1], -1), ) # shape: (ntypes, nloc, nnei) From 4c62a2602555eec2b4048c76ad995f96de6a5ec3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 10 Jan 2026 17:01:20 +0800 Subject: [PATCH 09/10] add return --- deepmd/dpmodel/utils/env_mat_stat.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index f23a24076d..68c1ca8129 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -99,6 +99,7 @@ def iter( if len(data) == 0: # workaround to fix IndexError: list index out of range yield from () + return xp = array_api_compat.array_namespace(data[0]["coord"]) zero_mean = xp.zeros( ( From 21b4c06141a75155b48ade085afd4be70f9097b6 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 10 Jan 2026 23:54:22 +0800 Subject: [PATCH 10/10] fix: validate last_dim in EnvMatStatSe for radial-only and full descriptor --- deepmd/dpmodel/utils/env_mat_stat.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index 68c1ca8129..37a69ea1b1 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -96,6 +96,14 @@ def iter( dict[str, StatItem] The statistics of the environment matrix. """ + if self.last_dim == 4: + radial_only = False + elif self.last_dim == 1: + radial_only = True + else: + raise ValueError( + "last_dim should be 1 for raial-only or 4 for full descriptor." + ) if len(data) == 0: # workaround to fix IndexError: list index out of range yield from () @@ -119,14 +127,6 @@ def iter( dtype=get_xp_precision(xp, "global"), device=array_api_compat.device(data[0]["coord"]), ) - if self.last_dim == 4: - radial_only = False - elif self.last_dim == 1: - radial_only = True - else: - raise ValueError( - "last_dim should be 1 for raial-only or 4 for full descriptor." - ) for system in data: coord, atype, box, natoms = ( system["coord"],