-
Notifications
You must be signed in to change notification settings - Fork 569
Test paral auto improve #4900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: devel
Are you sure you want to change the base?
Test paral auto improve #4900
Changes from all commits
8118231
29bedd3
36c8311
7e07126
5103f5d
90bd05b
d264fae
6528a85
3a6438e
3b09f93
318dab0
d236285
7ce9af1
78fe1b8
35783bb
95c1377
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3,6 +3,7 @@ | |||||||||||||||||||||||||||||
import paddle | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
from deepmd.pd.utils.preprocess import ( | ||||||||||||||||||||||||||||||
compute_exp_sw, | ||||||||||||||||||||||||||||||
compute_smooth_weight, | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -14,6 +15,7 @@ def _make_env_mat( | |||||||||||||||||||||||||||||
ruct_smth: float, | ||||||||||||||||||||||||||||||
radial_only: bool = False, | ||||||||||||||||||||||||||||||
protection: float = 0.0, | ||||||||||||||||||||||||||||||
use_exp_switch: bool = False, | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
"""Make smooth environment matrix.""" | ||||||||||||||||||||||||||||||
bsz, natoms, nnei = nlist.shape | ||||||||||||||||||||||||||||||
|
@@ -24,15 +26,20 @@ def _make_env_mat( | |||||||||||||||||||||||||||||
nlist = paddle.where(mask, nlist, nall - 1) | ||||||||||||||||||||||||||||||
coord_l = coord[:, :natoms].reshape([bsz, -1, 1, 3]) | ||||||||||||||||||||||||||||||
index = nlist.reshape([bsz, -1]).unsqueeze(-1).expand([-1, -1, 3]) | ||||||||||||||||||||||||||||||
coord_r = paddle.take_along_axis(coord, axis=1, indices=index) | ||||||||||||||||||||||||||||||
coord_pad = paddle.concat([coord, coord[:, -1:, :] + rcut], axis=1) | ||||||||||||||||||||||||||||||
coord_r = paddle.take_along_axis(coord_pad, axis=1, indices=index) | ||||||||||||||||||||||||||||||
coord_r = coord_r.reshape([bsz, natoms, nnei, 3]) | ||||||||||||||||||||||||||||||
Comment on lines
26
to
31
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Padded sentinel is never indexed; nlist still redirects invalid entries to nall-1. You append one extra coordinate (coord_pad has length nall+1), but invalid nlist entries are replaced with nall - 1, which points to the last real atom, not the sentinel at index nall. This makes the padding ineffective. - nlist = paddle.where(mask, nlist, nall - 1)
+ # Redirect masked neighbors to the padded sentinel at index nall
+ nlist = paddle.where(mask, nlist, paddle.full_like(nlist, nall))
coord_l = coord[:, :natoms].reshape([bsz, -1, 1, 3])
index = nlist.reshape([bsz, -1]).unsqueeze(-1).expand([-1, -1, 3])
- coord_pad = paddle.concat([coord, coord[:, -1:, :] + rcut], axis=1)
+ coord_pad = paddle.concat([coord, coord[:, -1:, :] + rcut], axis=1)
coord_r = paddle.take_along_axis(coord_pad, axis=1, indices=index) Note: The specific sentinel value is irrelevant because weight and diff are masked; the key is to avoid out-of-bounds and to keep gradients defined. 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||
diff = coord_r - coord_l | ||||||||||||||||||||||||||||||
length = paddle.linalg.norm(diff, axis=-1, keepdim=True) | ||||||||||||||||||||||||||||||
# for index 0 nloc atom | ||||||||||||||||||||||||||||||
length = length + (~mask.unsqueeze(-1)).astype(length.dtype) | ||||||||||||||||||||||||||||||
t0 = 1 / (length + protection) | ||||||||||||||||||||||||||||||
t1 = diff / (length + protection) ** 2 | ||||||||||||||||||||||||||||||
weight = compute_smooth_weight(length, ruct_smth, rcut) | ||||||||||||||||||||||||||||||
weight = ( | ||||||||||||||||||||||||||||||
compute_smooth_weight(length, ruct_smth, rcut) | ||||||||||||||||||||||||||||||
if not use_exp_switch | ||||||||||||||||||||||||||||||
else compute_exp_sw(length, ruct_smth, rcut) | ||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||
weight = weight * mask.unsqueeze(-1).astype(weight.dtype) | ||||||||||||||||||||||||||||||
if radial_only: | ||||||||||||||||||||||||||||||
env_mat = t0 * weight | ||||||||||||||||||||||||||||||
|
@@ -51,6 +58,7 @@ def prod_env_mat( | |||||||||||||||||||||||||||||
rcut_smth: float, | ||||||||||||||||||||||||||||||
radial_only: bool = False, | ||||||||||||||||||||||||||||||
protection: float = 0.0, | ||||||||||||||||||||||||||||||
use_exp_switch: bool = False, | ||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||
"""Generate smooth environment matrix from atom coordinates and other context. | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
|
@@ -63,6 +71,7 @@ def prod_env_mat( | |||||||||||||||||||||||||||||
- rcut_smth: Smooth hyper-parameter for pair force & energy. | ||||||||||||||||||||||||||||||
- radial_only: Whether to return a full description or a radial-only descriptor. | ||||||||||||||||||||||||||||||
- protection: Protection parameter to prevent division by zero errors during calculations. | ||||||||||||||||||||||||||||||
- use_exp_switch: Whether to use the exponential switch function. | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
Returns | ||||||||||||||||||||||||||||||
------- | ||||||||||||||||||||||||||||||
|
@@ -75,6 +84,7 @@ def prod_env_mat( | |||||||||||||||||||||||||||||
rcut_smth, | ||||||||||||||||||||||||||||||
radial_only, | ||||||||||||||||||||||||||||||
protection=protection, | ||||||||||||||||||||||||||||||
use_exp_switch=use_exp_switch, | ||||||||||||||||||||||||||||||
) # shape [n_atom, dim, 4 or 1] | ||||||||||||||||||||||||||||||
t_avg = mean[atype] # [n_atom, dim, 4 or 1] | ||||||||||||||||||||||||||||||
t_std = stddev[atype] # [n_atom, dim, 4 or 1] | ||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Same conditional reshard for virial difference.
Align with the pattern above to avoid attribute errors off-mesh.
📝 Committable suggestion
🤖 Prompt for AI Agents