|
1 | 1 | # SPDX-License-Identifier: LGPL-3.0-or-later |
2 | | -from collections.abc import ( |
3 | | - Iterator, |
4 | | -) |
5 | | -from typing import ( |
6 | | - TYPE_CHECKING, |
7 | | -) |
8 | 2 |
|
9 | | -import numpy as np |
10 | | -import torch |
11 | 3 |
|
12 | | -from deepmd.common import ( |
13 | | - get_hash, |
14 | | -) |
15 | | -from deepmd.pt.model.descriptor.env_mat import ( |
16 | | - prod_env_mat, |
17 | | -) |
18 | | -from deepmd.pt.utils import ( |
19 | | - env, |
20 | | -) |
21 | | -from deepmd.pt.utils.exclude_mask import ( |
22 | | - PairExcludeMask, |
23 | | -) |
24 | | -from deepmd.pt.utils.nlist import ( |
25 | | - extend_input_and_build_neighbor_list, |
| 4 | +from deepmd.dpmodel.utils.env_mat_stat import ( |
| 5 | + EnvMatStat, |
| 6 | + EnvMatStatSe, |
26 | 7 | ) |
27 | | -from deepmd.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat |
28 | | -from deepmd.utils.env_mat_stat import ( |
29 | | - StatItem, |
30 | | -) |
31 | | - |
32 | | -if TYPE_CHECKING: |
33 | | - from deepmd.pt.model.descriptor import ( |
34 | | - DescriptorBlock, |
35 | | - ) |
36 | | - |
37 | | - |
38 | | -class EnvMatStat(BaseEnvMatStat): |
39 | | - def compute_stat(self, env_mat: dict[str, torch.Tensor]) -> dict[str, StatItem]: |
40 | | - """Compute the statistics of the environment matrix for a single system. |
41 | | -
|
42 | | - Parameters |
43 | | - ---------- |
44 | | - env_mat : torch.Tensor |
45 | | - The environment matrix. |
46 | | -
|
47 | | - Returns |
48 | | - ------- |
49 | | - dict[str, StatItem] |
50 | | - The statistics of the environment matrix. |
51 | | - """ |
52 | | - stats = {} |
53 | | - for kk, vv in env_mat.items(): |
54 | | - stats[kk] = StatItem( |
55 | | - number=vv.numel(), |
56 | | - sum=vv.sum().item(), |
57 | | - squared_sum=torch.square(vv).sum().item(), |
58 | | - ) |
59 | | - return stats |
60 | | - |
61 | | - |
62 | | -class EnvMatStatSe(EnvMatStat): |
63 | | - """Environmental matrix statistics for the se_a/se_r environmental matrix. |
64 | | -
|
65 | | - Parameters |
66 | | - ---------- |
67 | | - descriptor : DescriptorBlock |
68 | | - The descriptor of the model. |
69 | | - """ |
70 | | - |
71 | | - def __init__(self, descriptor: "DescriptorBlock") -> None: |
72 | | - super().__init__() |
73 | | - self.descriptor = descriptor |
74 | | - self.last_dim = ( |
75 | | - self.descriptor.ndescrpt // self.descriptor.nnei |
76 | | - ) # se_r=1, se_a=4 |
77 | | - |
78 | | - def iter( |
79 | | - self, data: list[dict[str, torch.Tensor | list[tuple[int, int]]]] |
80 | | - ) -> Iterator[dict[str, StatItem]]: |
81 | | - """Get the iterator of the environment matrix. |
82 | | -
|
83 | | - Parameters |
84 | | - ---------- |
85 | | - data : list[dict[str, Union[torch.Tensor, list[tuple[int, int]]]]] |
86 | | - The data. |
87 | | -
|
88 | | - Yields |
89 | | - ------ |
90 | | - dict[str, StatItem] |
91 | | - The statistics of the environment matrix. |
92 | | - """ |
93 | | - zero_mean = torch.zeros( |
94 | | - self.descriptor.get_ntypes(), |
95 | | - self.descriptor.get_nsel(), |
96 | | - self.last_dim, |
97 | | - dtype=env.GLOBAL_PT_FLOAT_PRECISION, |
98 | | - device=env.DEVICE, |
99 | | - ) |
100 | | - one_stddev = torch.ones( |
101 | | - self.descriptor.get_ntypes(), |
102 | | - self.descriptor.get_nsel(), |
103 | | - self.last_dim, |
104 | | - dtype=env.GLOBAL_PT_FLOAT_PRECISION, |
105 | | - device=env.DEVICE, |
106 | | - ) |
107 | | - if self.last_dim == 4: |
108 | | - radial_only = False |
109 | | - elif self.last_dim == 1: |
110 | | - radial_only = True |
111 | | - else: |
112 | | - raise ValueError( |
113 | | - "last_dim should be 1 for raial-only or 4 for full descriptor." |
114 | | - ) |
115 | | - for system in data: |
116 | | - coord, atype, box, natoms = ( |
117 | | - system["coord"], |
118 | | - system["atype"], |
119 | | - system["box"], |
120 | | - system["natoms"], |
121 | | - ) |
122 | | - ( |
123 | | - extended_coord, |
124 | | - extended_atype, |
125 | | - mapping, |
126 | | - nlist, |
127 | | - ) = extend_input_and_build_neighbor_list( |
128 | | - coord, |
129 | | - atype, |
130 | | - self.descriptor.get_rcut(), |
131 | | - self.descriptor.get_sel(), |
132 | | - mixed_types=self.descriptor.mixed_types(), |
133 | | - box=box, |
134 | | - ) |
135 | | - env_mat, _, _ = prod_env_mat( |
136 | | - extended_coord, |
137 | | - nlist, |
138 | | - atype, |
139 | | - zero_mean, |
140 | | - one_stddev, |
141 | | - self.descriptor.get_rcut(), |
142 | | - self.descriptor.get_rcut_smth(), |
143 | | - radial_only, |
144 | | - protection=self.descriptor.get_env_protection(), |
145 | | - ) |
146 | | - # apply excluded_types |
147 | | - exclude_mask = self.descriptor.emask(nlist, extended_atype) |
148 | | - env_mat *= exclude_mask.unsqueeze(-1) |
149 | | - # reshape to nframes * nloc at the atom level, |
150 | | - # so nframes/mixed_type do not matter |
151 | | - env_mat = env_mat.view( |
152 | | - coord.shape[0] * coord.shape[1], |
153 | | - self.descriptor.get_nsel(), |
154 | | - self.last_dim, |
155 | | - ) |
156 | | - atype = atype.view(coord.shape[0] * coord.shape[1]) |
157 | | - # (1, nloc) eq (ntypes, 1), so broadcast is possible |
158 | | - # shape: (ntypes, nloc) |
159 | | - type_idx = torch.eq( |
160 | | - atype.view(1, -1), |
161 | | - torch.arange( |
162 | | - self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32 |
163 | | - ).view(-1, 1), |
164 | | - ) |
165 | | - if "pair_exclude_types" in system: |
166 | | - # shape: (1, nloc, nnei) |
167 | | - exclude_mask = PairExcludeMask( |
168 | | - self.descriptor.get_ntypes(), system["pair_exclude_types"] |
169 | | - )(nlist, extended_atype).view(1, coord.shape[0] * coord.shape[1], -1) |
170 | | - # shape: (ntypes, nloc, nnei) |
171 | | - type_idx = torch.logical_and(type_idx.unsqueeze(-1), exclude_mask) |
172 | | - for type_i in range(self.descriptor.get_ntypes()): |
173 | | - dd = env_mat[type_idx[type_i]] |
174 | | - dd = dd.reshape([-1, self.last_dim]) # typen_atoms * unmasked_nnei, 4 |
175 | | - env_mats = {} |
176 | | - env_mats[f"r_{type_i}"] = dd[:, :1] |
177 | | - if self.last_dim == 4: |
178 | | - env_mats[f"a_{type_i}"] = dd[:, 1:] |
179 | | - yield self.compute_stat(env_mats) |
180 | | - |
181 | | - def get_hash(self) -> str: |
182 | | - """Get the hash of the environment matrix. |
183 | | -
|
184 | | - Returns |
185 | | - ------- |
186 | | - str |
187 | | - The hash of the environment matrix. |
188 | | - """ |
189 | | - dscpt_type = "se_a" if self.last_dim == 4 else "se_r" |
190 | | - return get_hash( |
191 | | - { |
192 | | - "type": dscpt_type, |
193 | | - "ntypes": self.descriptor.get_ntypes(), |
194 | | - "rcut": round(self.descriptor.get_rcut(), 2), |
195 | | - "rcut_smth": round(self.descriptor.rcut_smth, 2), |
196 | | - "nsel": self.descriptor.get_nsel(), |
197 | | - "sel": self.descriptor.get_sel(), |
198 | | - "mixed_types": self.descriptor.mixed_types(), |
199 | | - } |
200 | | - ) |
201 | | - |
202 | | - def __call__(self) -> tuple[np.ndarray, np.ndarray]: |
203 | | - avgs = self.get_avg() |
204 | | - stds = self.get_std() |
205 | | - |
206 | | - all_davg = [] |
207 | | - all_dstd = [] |
208 | | - |
209 | | - for type_i in range(self.descriptor.get_ntypes()): |
210 | | - if self.last_dim == 4: |
211 | | - davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] |
212 | | - dstdunit = [ |
213 | | - [ |
214 | | - stds[f"r_{type_i}"], |
215 | | - stds[f"a_{type_i}"], |
216 | | - stds[f"a_{type_i}"], |
217 | | - stds[f"a_{type_i}"], |
218 | | - ] |
219 | | - ] |
220 | | - elif self.last_dim == 1: |
221 | | - davgunit = [[avgs[f"r_{type_i}"]]] |
222 | | - dstdunit = [ |
223 | | - [ |
224 | | - stds[f"r_{type_i}"], |
225 | | - ] |
226 | | - ] |
227 | | - davg = np.tile(davgunit, [self.descriptor.get_nsel(), 1]) |
228 | | - dstd = np.tile(dstdunit, [self.descriptor.get_nsel(), 1]) |
229 | | - all_davg.append(davg) |
230 | | - all_dstd.append(dstd) |
231 | 8 |
|
232 | | - mean = np.stack(all_davg) |
233 | | - stddev = np.stack(all_dstd) |
234 | | - return mean, stddev |
| 9 | +__all__ = [ |
| 10 | + "EnvMatStat", |
| 11 | + "EnvMatStatSe", |
| 12 | +] |
0 commit comments