Skip to content

Commit b5d587c

Browse files
refactor(pt): reuse dpmodel EnvMatStat (#5139)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Consolidated duplicated utilities by re-exporting shared implementations to simplify maintenance. * **Bug Fixes** * Improved device/dtype consistency for array operations and masks to ensure tensors are placed and computed correctly across backends. * Fixed edge-case handling to avoid errors on empty data and improved reuse of exclusion masks for correctness. * **New Features** * Added a convenient public alias for mask-building functionality. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2802ce7 commit b5d587c

File tree

6 files changed

+51
-245
lines changed

6 files changed

+51
-245
lines changed

deepmd/dpmodel/utils/env_mat_stat.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]:
5858
for kk, vv in env_mat.items():
5959
xp = array_api_compat.array_namespace(vv)
6060
stats[kk] = StatItem(
61-
number=vv.size,
61+
number=array_api_compat.size(vv),
6262
sum=float(xp.sum(vv)),
6363
squared_sum=float(xp.sum(xp.square(vv))),
6464
)
@@ -96,6 +96,18 @@ def iter(
9696
dict[str, StatItem]
9797
The statistics of the environment matrix.
9898
"""
99+
if self.last_dim == 4:
100+
radial_only = False
101+
elif self.last_dim == 1:
102+
radial_only = True
103+
else:
104+
raise ValueError(
105+
"last_dim should be 1 for raial-only or 4 for full descriptor."
106+
)
107+
if len(data) == 0:
108+
# workaround to fix IndexError: list index out of range
109+
yield from ()
110+
return
99111
xp = array_api_compat.array_namespace(data[0]["coord"])
100112
zero_mean = xp.zeros(
101113
(
@@ -104,6 +116,7 @@ def iter(
104116
self.last_dim,
105117
),
106118
dtype=get_xp_precision(xp, "global"),
119+
device=array_api_compat.device(data[0]["coord"]),
107120
)
108121
one_stddev = xp.ones(
109122
(
@@ -112,15 +125,8 @@ def iter(
112125
self.last_dim,
113126
),
114127
dtype=get_xp_precision(xp, "global"),
128+
device=array_api_compat.device(data[0]["coord"]),
115129
)
116-
if self.last_dim == 4:
117-
radial_only = False
118-
elif self.last_dim == 1:
119-
radial_only = True
120-
else:
121-
raise ValueError(
122-
"last_dim should be 1 for raial-only or 4 for full descriptor."
123-
)
124130
for system in data:
125131
coord, atype, box, natoms = (
126132
system["coord"],
@@ -175,16 +181,25 @@ def iter(
175181
type_idx = xp.equal(
176182
xp.reshape(atype, (1, -1)),
177183
xp.reshape(
178-
xp.arange(self.descriptor.get_ntypes(), dtype=xp.int32),
184+
xp.arange(
185+
self.descriptor.get_ntypes(),
186+
dtype=xp.int32,
187+
device=array_api_compat.device(atype),
188+
),
179189
(-1, 1),
180190
),
181191
)
182192
if "pair_exclude_types" in system:
193+
pair_exclude_mask = PairExcludeMask(
194+
self.descriptor.get_ntypes(), system["pair_exclude_types"]
195+
)
196+
pair_exclude_mask.type_mask = xp.asarray(
197+
pair_exclude_mask.type_mask,
198+
device=array_api_compat.device(atype),
199+
)
183200
# shape: (1, nloc, nnei)
184201
exclude_mask = xp.reshape(
185-
PairExcludeMask(
186-
self.descriptor.get_ntypes(), system["pair_exclude_types"]
187-
).build_type_exclude_mask(nlist, extended_atype),
202+
pair_exclude_mask.build_type_exclude_mask(nlist, extended_atype),
188203
(1, coord.shape[0] * coord.shape[1], -1),
189204
)
190205
# shape: (ntypes, nloc, nnei)

deepmd/dpmodel/utils/exclude_mask.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,16 @@ def build_type_exclude_mask(
120120
nall = atype_ext.shape[1]
121121
# add virtual atom of type ntypes. nf x nall+1
122122
ae = xp.concat(
123-
[atype_ext, self.ntypes * xp.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1
123+
[
124+
atype_ext,
125+
self.ntypes
126+
* xp.ones(
127+
[nf, 1],
128+
dtype=atype_ext.dtype,
129+
device=array_api_compat.device(atype_ext),
130+
),
131+
],
132+
axis=-1,
124133
)
125134
type_i = xp.reshape(atype_ext[:, :nloc], (nf, nloc)) * (self.ntypes + 1)
126135
# nf x nloc x nnei

deepmd/dpmodel/utils/nlist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def build_neighbor_list(
9696
nall = coord.shape[1] // 3
9797
# fill virtual atoms with large coords so they are not neighbors of any
9898
# real atom.
99-
if coord.size > 0:
99+
if array_api_compat.size(coord) > 0:
100100
xmax = xp.max(coord) + 2.0 * rcut
101101
else:
102102
xmax = 2.0 * rcut

deepmd/dpmodel/utils/region.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def normalize_coord(
7474
"""
7575
xp = array_api_compat.array_namespace(coord, cell)
7676
icoord = phys2inter(coord, cell)
77-
icoord = xp.remainder(icoord, xp.asarray(1.0))
77+
icoord = xp.remainder(
78+
icoord, xp.ones((), dtype=icoord.dtype, device=array_api_compat.device(icoord))
79+
)
7880
return inter2phys(icoord, cell)
7981

8082

deepmd/pt/utils/env_mat_stat.py

Lines changed: 7 additions & 229 deletions
Original file line numberDiff line numberDiff line change
@@ -1,234 +1,12 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
from collections.abc import (
3-
Iterator,
4-
)
5-
from typing import (
6-
TYPE_CHECKING,
7-
)
82

9-
import numpy as np
10-
import torch
113

12-
from deepmd.common import (
13-
get_hash,
14-
)
15-
from deepmd.pt.model.descriptor.env_mat import (
16-
prod_env_mat,
17-
)
18-
from deepmd.pt.utils import (
19-
env,
20-
)
21-
from deepmd.pt.utils.exclude_mask import (
22-
PairExcludeMask,
23-
)
24-
from deepmd.pt.utils.nlist import (
25-
extend_input_and_build_neighbor_list,
4+
from deepmd.dpmodel.utils.env_mat_stat import (
5+
EnvMatStat,
6+
EnvMatStatSe,
267
)
27-
from deepmd.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat
28-
from deepmd.utils.env_mat_stat import (
29-
StatItem,
30-
)
31-
32-
if TYPE_CHECKING:
33-
from deepmd.pt.model.descriptor import (
34-
DescriptorBlock,
35-
)
36-
37-
38-
class EnvMatStat(BaseEnvMatStat):
39-
def compute_stat(self, env_mat: dict[str, torch.Tensor]) -> dict[str, StatItem]:
40-
"""Compute the statistics of the environment matrix for a single system.
41-
42-
Parameters
43-
----------
44-
env_mat : torch.Tensor
45-
The environment matrix.
46-
47-
Returns
48-
-------
49-
dict[str, StatItem]
50-
The statistics of the environment matrix.
51-
"""
52-
stats = {}
53-
for kk, vv in env_mat.items():
54-
stats[kk] = StatItem(
55-
number=vv.numel(),
56-
sum=vv.sum().item(),
57-
squared_sum=torch.square(vv).sum().item(),
58-
)
59-
return stats
60-
61-
62-
class EnvMatStatSe(EnvMatStat):
63-
"""Environmental matrix statistics for the se_a/se_r environmental matrix.
64-
65-
Parameters
66-
----------
67-
descriptor : DescriptorBlock
68-
The descriptor of the model.
69-
"""
70-
71-
def __init__(self, descriptor: "DescriptorBlock") -> None:
72-
super().__init__()
73-
self.descriptor = descriptor
74-
self.last_dim = (
75-
self.descriptor.ndescrpt // self.descriptor.nnei
76-
) # se_r=1, se_a=4
77-
78-
def iter(
79-
self, data: list[dict[str, torch.Tensor | list[tuple[int, int]]]]
80-
) -> Iterator[dict[str, StatItem]]:
81-
"""Get the iterator of the environment matrix.
82-
83-
Parameters
84-
----------
85-
data : list[dict[str, Union[torch.Tensor, list[tuple[int, int]]]]]
86-
The data.
87-
88-
Yields
89-
------
90-
dict[str, StatItem]
91-
The statistics of the environment matrix.
92-
"""
93-
zero_mean = torch.zeros(
94-
self.descriptor.get_ntypes(),
95-
self.descriptor.get_nsel(),
96-
self.last_dim,
97-
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
98-
device=env.DEVICE,
99-
)
100-
one_stddev = torch.ones(
101-
self.descriptor.get_ntypes(),
102-
self.descriptor.get_nsel(),
103-
self.last_dim,
104-
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
105-
device=env.DEVICE,
106-
)
107-
if self.last_dim == 4:
108-
radial_only = False
109-
elif self.last_dim == 1:
110-
radial_only = True
111-
else:
112-
raise ValueError(
113-
"last_dim should be 1 for raial-only or 4 for full descriptor."
114-
)
115-
for system in data:
116-
coord, atype, box, natoms = (
117-
system["coord"],
118-
system["atype"],
119-
system["box"],
120-
system["natoms"],
121-
)
122-
(
123-
extended_coord,
124-
extended_atype,
125-
mapping,
126-
nlist,
127-
) = extend_input_and_build_neighbor_list(
128-
coord,
129-
atype,
130-
self.descriptor.get_rcut(),
131-
self.descriptor.get_sel(),
132-
mixed_types=self.descriptor.mixed_types(),
133-
box=box,
134-
)
135-
env_mat, _, _ = prod_env_mat(
136-
extended_coord,
137-
nlist,
138-
atype,
139-
zero_mean,
140-
one_stddev,
141-
self.descriptor.get_rcut(),
142-
self.descriptor.get_rcut_smth(),
143-
radial_only,
144-
protection=self.descriptor.get_env_protection(),
145-
)
146-
# apply excluded_types
147-
exclude_mask = self.descriptor.emask(nlist, extended_atype)
148-
env_mat *= exclude_mask.unsqueeze(-1)
149-
# reshape to nframes * nloc at the atom level,
150-
# so nframes/mixed_type do not matter
151-
env_mat = env_mat.view(
152-
coord.shape[0] * coord.shape[1],
153-
self.descriptor.get_nsel(),
154-
self.last_dim,
155-
)
156-
atype = atype.view(coord.shape[0] * coord.shape[1])
157-
# (1, nloc) eq (ntypes, 1), so broadcast is possible
158-
# shape: (ntypes, nloc)
159-
type_idx = torch.eq(
160-
atype.view(1, -1),
161-
torch.arange(
162-
self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32
163-
).view(-1, 1),
164-
)
165-
if "pair_exclude_types" in system:
166-
# shape: (1, nloc, nnei)
167-
exclude_mask = PairExcludeMask(
168-
self.descriptor.get_ntypes(), system["pair_exclude_types"]
169-
)(nlist, extended_atype).view(1, coord.shape[0] * coord.shape[1], -1)
170-
# shape: (ntypes, nloc, nnei)
171-
type_idx = torch.logical_and(type_idx.unsqueeze(-1), exclude_mask)
172-
for type_i in range(self.descriptor.get_ntypes()):
173-
dd = env_mat[type_idx[type_i]]
174-
dd = dd.reshape([-1, self.last_dim]) # typen_atoms * unmasked_nnei, 4
175-
env_mats = {}
176-
env_mats[f"r_{type_i}"] = dd[:, :1]
177-
if self.last_dim == 4:
178-
env_mats[f"a_{type_i}"] = dd[:, 1:]
179-
yield self.compute_stat(env_mats)
180-
181-
def get_hash(self) -> str:
182-
"""Get the hash of the environment matrix.
183-
184-
Returns
185-
-------
186-
str
187-
The hash of the environment matrix.
188-
"""
189-
dscpt_type = "se_a" if self.last_dim == 4 else "se_r"
190-
return get_hash(
191-
{
192-
"type": dscpt_type,
193-
"ntypes": self.descriptor.get_ntypes(),
194-
"rcut": round(self.descriptor.get_rcut(), 2),
195-
"rcut_smth": round(self.descriptor.rcut_smth, 2),
196-
"nsel": self.descriptor.get_nsel(),
197-
"sel": self.descriptor.get_sel(),
198-
"mixed_types": self.descriptor.mixed_types(),
199-
}
200-
)
201-
202-
def __call__(self) -> tuple[np.ndarray, np.ndarray]:
203-
avgs = self.get_avg()
204-
stds = self.get_std()
205-
206-
all_davg = []
207-
all_dstd = []
208-
209-
for type_i in range(self.descriptor.get_ntypes()):
210-
if self.last_dim == 4:
211-
davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]]
212-
dstdunit = [
213-
[
214-
stds[f"r_{type_i}"],
215-
stds[f"a_{type_i}"],
216-
stds[f"a_{type_i}"],
217-
stds[f"a_{type_i}"],
218-
]
219-
]
220-
elif self.last_dim == 1:
221-
davgunit = [[avgs[f"r_{type_i}"]]]
222-
dstdunit = [
223-
[
224-
stds[f"r_{type_i}"],
225-
]
226-
]
227-
davg = np.tile(davgunit, [self.descriptor.get_nsel(), 1])
228-
dstd = np.tile(dstdunit, [self.descriptor.get_nsel(), 1])
229-
all_davg.append(davg)
230-
all_dstd.append(dstd)
2318

232-
mean = np.stack(all_davg)
233-
stddev = np.stack(all_dstd)
234-
return mean, stddev
9+
__all__ = [
10+
"EnvMatStat",
11+
"EnvMatStatSe",
12+
]

deepmd/pt/utils/exclude_mask.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,5 @@ def forward(
147147
type_ij = type_ij.view(nf, nloc * nnei)
148148
mask = self.type_mask[type_ij].view(nf, nloc, nnei).to(atype_ext.device)
149149
return mask
150+
151+
build_type_exclude_mask = forward

0 commit comments

Comments
 (0)