diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 6933fdc19a..a77e5391f8 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -392,6 +392,11 @@ def _get_pairwise_dist(coords: torch.Tensor, nlist: torch.Tensor) -> torch.Tenso ------- torch.Tensor The pairwise distance between the atoms (nframes, nloc, nnei). + + Notes + ----- + Safe gradient implementation: when diff is zero (padding entries), + both distance and gradient are zero. """ nframes, nloc, nnei = nlist.shape coord_l = coords[:, :nloc].view(nframes, -1, 1, 3) @@ -399,7 +404,17 @@ def _get_pairwise_dist(coords: torch.Tensor, nlist: torch.Tensor) -> torch.Tenso coord_r = torch.gather(coords, 1, index) coord_r = coord_r.view(nframes, nloc, nnei, 3) diff = coord_r - coord_l - pairwise_rr = torch.linalg.norm(diff, dim=-1, keepdim=True).squeeze(-1) + diff_sq = torch.sum(diff * diff, dim=-1, keepdim=True) + + # When diff is zero, output is zero and gradient is also zero + mask = diff_sq.squeeze(-1) > 0 + pairwise_rr = torch.where( + mask.unsqueeze(-1), + torch.sqrt( + torch.where(mask.unsqueeze(-1), diff_sq, torch.ones_like(diff_sq)) + ), + torch.zeros_like(diff_sq), + ).squeeze(-1) return pairwise_rr @staticmethod