-
Notifications
You must be signed in to change notification settings - Fork 585
refactor(pt): reuse dpmodel EnvMatStat #5139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
0e5b449
refactor(pt): reuse env mat stat
njzjz c895282
sync from #5137
njzjz 0521725
fix other errors
njzjz b652f03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ad01332
fix: improve remainder calculation in normalize_coord function
njzjz c690b5d
Merge branch 'pt-reuse-env-mat-stat' of https://github.com/njzjz/deep…
njzjz bc70c42
Merge branch 'master' into pt-reuse-env-mat-stat
njzjz e72375e
fix xp.ones
njzjz 4058f9b
fix IndexError: list index out of range
njzjz a46cffd
fix PairExcludeMask
njzjz 4c62a26
add return
njzjz 21b4c06
fix: validate last_dim in EnvMatStatSe for radial-only and full descr…
njzjz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,234 +1,12 @@ | ||
| # SPDX-License-Identifier: LGPL-3.0-or-later | ||
| from collections.abc import ( | ||
| Iterator, | ||
| ) | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| ) | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from deepmd.common import ( | ||
| get_hash, | ||
| ) | ||
| from deepmd.pt.model.descriptor.env_mat import ( | ||
| prod_env_mat, | ||
| ) | ||
| from deepmd.pt.utils import ( | ||
| env, | ||
| ) | ||
| from deepmd.pt.utils.exclude_mask import ( | ||
| PairExcludeMask, | ||
| ) | ||
| from deepmd.pt.utils.nlist import ( | ||
| extend_input_and_build_neighbor_list, | ||
| from deepmd.dpmodel.utils.env_mat_stat import ( | ||
| EnvMatStat, | ||
| EnvMatStatSe, | ||
| ) | ||
| from deepmd.utils.env_mat_stat import EnvMatStat as BaseEnvMatStat | ||
| from deepmd.utils.env_mat_stat import ( | ||
| StatItem, | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from deepmd.pt.model.descriptor import ( | ||
| DescriptorBlock, | ||
| ) | ||
|
|
||
|
|
||
| class EnvMatStat(BaseEnvMatStat): | ||
| def compute_stat(self, env_mat: dict[str, torch.Tensor]) -> dict[str, StatItem]: | ||
| """Compute the statistics of the environment matrix for a single system. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| env_mat : dict[str, torch.Tensor] | ||
| The environment matrix. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict[str, StatItem] | ||
| The statistics of the environment matrix. | ||
| """ | ||
| stats = {} | ||
| for kk, vv in env_mat.items(): | ||
| stats[kk] = StatItem( | ||
| number=vv.numel(), | ||
| sum=vv.sum().item(), | ||
| squared_sum=torch.square(vv).sum().item(), | ||
| ) | ||
| return stats | ||
|
|
||
|
|
||
| class EnvMatStatSe(EnvMatStat): | ||
| """Environmental matrix statistics for the se_a/se_r environmental matrix. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| descriptor : DescriptorBlock | ||
| The descriptor of the model. | ||
| """ | ||
|
|
||
| def __init__(self, descriptor: "DescriptorBlock") -> None: | ||
| super().__init__() | ||
| self.descriptor = descriptor | ||
| self.last_dim = ( | ||
| self.descriptor.ndescrpt // self.descriptor.nnei | ||
| ) # se_r=1, se_a=4 | ||
|
|
||
| def iter( | ||
| self, data: list[dict[str, torch.Tensor | list[tuple[int, int]]]] | ||
| ) -> Iterator[dict[str, StatItem]]: | ||
| """Get the iterator of the environment matrix. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data : list[dict[str, Union[torch.Tensor, list[tuple[int, int]]]]] | ||
| The data. | ||
|
|
||
| Yields | ||
| ------ | ||
| dict[str, StatItem] | ||
| The statistics of the environment matrix. | ||
| """ | ||
| zero_mean = torch.zeros( | ||
| self.descriptor.get_ntypes(), | ||
| self.descriptor.get_nsel(), | ||
| self.last_dim, | ||
| dtype=env.GLOBAL_PT_FLOAT_PRECISION, | ||
| device=env.DEVICE, | ||
| ) | ||
| one_stddev = torch.ones( | ||
| self.descriptor.get_ntypes(), | ||
| self.descriptor.get_nsel(), | ||
| self.last_dim, | ||
| dtype=env.GLOBAL_PT_FLOAT_PRECISION, | ||
| device=env.DEVICE, | ||
| ) | ||
| if self.last_dim == 4: | ||
| radial_only = False | ||
| elif self.last_dim == 1: | ||
| radial_only = True | ||
| else: | ||
| raise ValueError( | ||
| "last_dim should be 1 for radial-only or 4 for full descriptor." | ||
| ) | ||
| for system in data: | ||
| coord, atype, box, natoms = ( | ||
| system["coord"], | ||
| system["atype"], | ||
| system["box"], | ||
| system["natoms"], | ||
| ) | ||
| ( | ||
| extended_coord, | ||
| extended_atype, | ||
| mapping, | ||
| nlist, | ||
| ) = extend_input_and_build_neighbor_list( | ||
| coord, | ||
| atype, | ||
| self.descriptor.get_rcut(), | ||
| self.descriptor.get_sel(), | ||
| mixed_types=self.descriptor.mixed_types(), | ||
| box=box, | ||
| ) | ||
| env_mat, _, _ = prod_env_mat( | ||
| extended_coord, | ||
| nlist, | ||
| atype, | ||
| zero_mean, | ||
| one_stddev, | ||
| self.descriptor.get_rcut(), | ||
| self.descriptor.get_rcut_smth(), | ||
| radial_only, | ||
| protection=self.descriptor.get_env_protection(), | ||
| ) | ||
| # apply excluded_types | ||
| exclude_mask = self.descriptor.emask(nlist, extended_atype) | ||
| env_mat *= exclude_mask.unsqueeze(-1) | ||
| # reshape to nframes * nloc at the atom level, | ||
| # so nframes/mixed_type do not matter | ||
| env_mat = env_mat.view( | ||
| coord.shape[0] * coord.shape[1], | ||
| self.descriptor.get_nsel(), | ||
| self.last_dim, | ||
| ) | ||
| atype = atype.view(coord.shape[0] * coord.shape[1]) | ||
| # (1, nloc) eq (ntypes, 1), so broadcast is possible | ||
| # shape: (ntypes, nloc) | ||
| type_idx = torch.eq( | ||
| atype.view(1, -1), | ||
| torch.arange( | ||
| self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32 | ||
| ).view(-1, 1), | ||
| ) | ||
| if "pair_exclude_types" in system: | ||
| # shape: (1, nloc, nnei) | ||
| exclude_mask = PairExcludeMask( | ||
| self.descriptor.get_ntypes(), system["pair_exclude_types"] | ||
| )(nlist, extended_atype).view(1, coord.shape[0] * coord.shape[1], -1) | ||
| # shape: (ntypes, nloc, nnei) | ||
| type_idx = torch.logical_and(type_idx.unsqueeze(-1), exclude_mask) | ||
| for type_i in range(self.descriptor.get_ntypes()): | ||
| dd = env_mat[type_idx[type_i]] | ||
| dd = dd.reshape([-1, self.last_dim]) # (type_i atoms * unmasked nnei, last_dim) | ||
| env_mats = {} | ||
| env_mats[f"r_{type_i}"] = dd[:, :1] | ||
| if self.last_dim == 4: | ||
| env_mats[f"a_{type_i}"] = dd[:, 1:] | ||
| yield self.compute_stat(env_mats) | ||
|
|
||
| def get_hash(self) -> str: | ||
| """Get the hash of the environment matrix. | ||
|
|
||
| Returns | ||
| ------- | ||
| str | ||
| The hash of the environment matrix. | ||
| """ | ||
| dscpt_type = "se_a" if self.last_dim == 4 else "se_r" | ||
| return get_hash( | ||
| { | ||
| "type": dscpt_type, | ||
| "ntypes": self.descriptor.get_ntypes(), | ||
| "rcut": round(self.descriptor.get_rcut(), 2), | ||
| "rcut_smth": round(self.descriptor.rcut_smth, 2), | ||
| "nsel": self.descriptor.get_nsel(), | ||
| "sel": self.descriptor.get_sel(), | ||
| "mixed_types": self.descriptor.mixed_types(), | ||
| } | ||
| ) | ||
|
|
||
| def __call__(self) -> tuple[np.ndarray, np.ndarray]: | ||
| avgs = self.get_avg() | ||
| stds = self.get_std() | ||
|
|
||
| all_davg = [] | ||
| all_dstd = [] | ||
|
|
||
| for type_i in range(self.descriptor.get_ntypes()): | ||
| if self.last_dim == 4: | ||
| davgunit = [[avgs[f"r_{type_i}"], 0, 0, 0]] | ||
| dstdunit = [ | ||
| [ | ||
| stds[f"r_{type_i}"], | ||
| stds[f"a_{type_i}"], | ||
| stds[f"a_{type_i}"], | ||
| stds[f"a_{type_i}"], | ||
| ] | ||
| ] | ||
| elif self.last_dim == 1: | ||
| davgunit = [[avgs[f"r_{type_i}"]]] | ||
| dstdunit = [ | ||
| [ | ||
| stds[f"r_{type_i}"], | ||
| ] | ||
| ] | ||
| davg = np.tile(davgunit, [self.descriptor.get_nsel(), 1]) | ||
| dstd = np.tile(dstdunit, [self.descriptor.get_nsel(), 1]) | ||
| all_davg.append(davg) | ||
| all_dstd.append(dstd) | ||
|
|
||
| mean = np.stack(all_davg) | ||
| stddev = np.stack(all_dstd) | ||
| return mean, stddev | ||
| __all__ = [ | ||
| "EnvMatStat", | ||
| "EnvMatStatSe", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.