Skip to content

Commit ef1808a

Browse files
authored
force fp64 in brute force edge construction (#91)
1 parent c87438b commit ef1808a

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

orb_models/forcefield/featurization_utilities.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,11 +620,19 @@ def compute_supercell_neighbors(
620620
n_workers (int, optional): The number of workers to use for KDTree construction. Defaults to 1.
621621
"""
622622
if edge_method == "knn_brute_force":
623-
distances = torch.cdist(central_cell_positions, supercell_positions)
623+
624+
# Always use float64 for distance calculations, because
625+
# torch.cdist can be quite inprecise for float32 when use_mm_for_euclid_dist is True.
626+
# This can lead to incorrect edge selection.
627+
original_dtype = central_cell_positions.dtype
628+
central_cell_positions_f64 = central_cell_positions.to(torch.float64)
629+
supercell_positions_f64 = supercell_positions.to(torch.float64)
630+
distances = torch.cdist(central_cell_positions_f64, supercell_positions_f64)
624631
k = min(max_num_neighbors + 1, len(supercell_positions))
625632
distances, supercell_receivers = torch.topk(
626633
distances, k=k, largest=False, sorted=True
627634
)
635+
distances = distances.to(original_dtype)
628636
# remove self-edges and edges beyond radius
629637
within_radius = distances[:, 1:] < (radius + 1e-6)
630638
num_neighbors_per_sender = within_radius.sum(-1)

0 commit comments

Comments
 (0)