Skip to content

Commit 385e091

Browse files
committed
Add cuML KNN for graph construction, better handling of non-PBC systems
1 parent 9a68e46 commit 385e091

File tree

10 files changed

+440
-55
lines changed

10 files changed

+440
-55
lines changed

CONTRIBUTING.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ pip install poetry # Install Poetry if you don't have it
77
poetry install
88
```
99

10+
Optionally, also install [cuML](https://docs.rapids.ai/install/) (requires CUDA):
11+
```bash
12+
pip install --extra-index-url=https://pypi.nvidia.com "cuml-cu11==25.2.*" # For cuda versions >=11.4, <11.8
13+
pip install --extra-index-url=https://pypi.nvidia.com "cuml-cu12==25.2.*" # For cuda versions >=12.0, <13.0
14+
```
15+
1016
### Running tests
1117

1218
The `orb_models` package uses `pytest` for testing. To run the tests, navigate to the root directory of the package and run the following command:

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ pip install orb-models
2222

2323
Orb models are expected to work on MacOS and Linux. Windows support is not guaranteed.
2424

25+
For large system (≳5k atoms PBC, or ≳30k atoms non-PBC) simulations we recommend installing [cuML](https://docs.rapids.ai/install/) (requires CUDA), which can significantly reduce graph creation time (2-10x) and improve GPU memory efficiency (2-100x):
26+
```bash
27+
pip install --extra-index-url=https://pypi.nvidia.com "cuml-cu11==25.2.*" # For cuda versions >=11.4, <11.8
28+
pip install --extra-index-url=https://pypi.nvidia.com "cuml-cu12==25.2.*" # For cuda versions >=12.0, <13.0
29+
```
30+
2531
### Updates
2632

2733
**Oct 2024**: We have released a new version of the models, `orb-v2`. This version has 2 major changes:

internal/check.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import ase
66
import torch
7+
import numpy as np
8+
79
from core.dataset import atomic_system as core_atomic_system
810
from core.models import load
911

@@ -21,7 +23,12 @@ def main(model: str, core_model: str):
2123
"""
2224
original_orbff, _, sys_config = load.load_model(core_model)
2325

24-
atoms = ase.Atoms("H2O", positions=[[0, 0, 0], [0, 0, 1.1], [0, 1.1, 0]])
26+
atoms = ase.Atoms(
27+
"H2O",
28+
positions=[[0, 0, 0], [0, 0, 1.1], [0, 1.1, 0]],
29+
cell=np.eye(3) * 2,
30+
pbc=True,
31+
)
2532

2633
graph_orig = core_atomic_system.ase_atoms_to_atom_graphs(atoms, sys_config)
2734
graph = atomic_system.ase_atoms_to_atom_graphs(atoms)

orb_models/forcefield/atomic_system.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,16 @@ def ase_atoms_to_atom_graphs(
171171
Args:
172172
atoms: ase.Atoms object
173173
wrap: whether to wrap atomic positions into the central unit cell (if there is one).
174-
edge_method: The method to use for edge creation:
175-
- knn_brute_force: Use brute force to find nearest neighbors.
176-
- knn_scipy (default): Use scipy to find nearest neighbors.
174+
edge_method (EdgeCreationMethod, optional): The method to use for graph edge construction.
175+
If None, the edge method is chosen as follows:
176+
* knn_brute_force: If device is not CPU, and cuML is not installed or num_atoms is < 5000 (PBC)
177+
or < 30000 (non-PBC).
178+
* knn_cuml_rbc: If device is not CPU, and cuML is installed, and num_atoms is >= 5000 (PBC) or
179+
>= 30000 (non-PBC).
180+
* knn_scipy (default): If device is CPU.
181+
On GPU, for num_atoms ≲ 5000 (PBC) or ≲ 30000 (non-PBC), knn_brute_force is faster than knn_cuml_*,
182+
but uses more memory. For num_atoms ≳ 5000 (PBC) or ≳ 30000 (non-PBC), knn_cuml_* is faster and uses
183+
less memory, but requires cuML to be installed. knn_scipy is typically fastest on the CPU.
177184
system_config: The system configuration to use for graph construction.
178185
max_num_neighbors: Maximum number of neighbors each node can send messages to.
179186
If None, will use system_config.max_num_neighbors.
@@ -214,13 +221,15 @@ def ase_atoms_to_atom_graphs(
214221
)
215222
positions = torch.from_numpy(atoms.positions)
216223
cell = torch.from_numpy(atoms.cell.array)
224+
pbc = torch.from_numpy(atoms.pbc)
217225
lattice = torch.from_numpy(cell_to_cellpar(cell))
218-
if wrap and torch.any(cell != 0):
226+
if wrap and (torch.any(cell != 0) and torch.any(pbc)):
219227
positions = feat_util.map_to_pbc_cell(positions, cell)
220228

221229
edge_index, edge_vectors, unit_shifts = feat_util.compute_pbc_radius_graph(
222230
positions=positions,
223231
cell=cell,
232+
pbc=pbc,
224233
radius=system_config.radius,
225234
max_number_neighbors=max_num_neighbors,
226235
edge_method=edge_method,
@@ -248,6 +257,7 @@ def ase_atoms_to_atom_graphs(
248257
graph_feats = {
249258
**atoms.info.get("graph_features", {}),
250259
"cell": cell,
260+
"pbc": pbc,
251261
"lattice": lattice,
252262
}
253263

orb_models/forcefield/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ def cell(self, val: torch.Tensor):
113113
assert self.system_features
114114
self.system_features["cell"] = val
115115

116+
@property
117+
def pbc(self):
118+
"""Get pbc."""
119+
assert self.system_features
120+
return self.system_features.get("pbc")
121+
116122
def compute_differentiable_edge_vectors(
117123
self,
118124
use_stress_displacement: bool = True,
@@ -477,6 +483,7 @@ def refeaturize_atomgraphs(
477483
) = featurization_utilities.batch_compute_pbc_radius_graph(
478484
positions=positions,
479485
cells=cell,
486+
pbc=atoms.pbc,
480487
radius=atoms.radius,
481488
n_node=num_atoms,
482489
max_number_neighbors=atoms.max_num_neighbors,

orb_models/forcefield/featurization_utilities.py

Lines changed: 91 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
"""Featurization utilities for molecular models."""
22

3+
import typing
34
from typing import Optional, Tuple, Union, Literal, List
45

56
import ase
67
import numpy as np
78
import torch
89

10+
try:
11+
import cuml
12+
except ImportError:
13+
cuml = None
14+
915
from scipy.spatial import KDTree as SciKDTree
1016

1117

@@ -29,6 +35,25 @@ def get_device(
2935
return torch.device(requested_device)
3036

3137

38+
def get_default_edge_method(
39+
device: torch.device, num_atoms: int, is_periodic: bool
40+
) -> EdgeCreationMethod:
41+
"""Get the default edge method for a given device and number of atoms."""
42+
if device.type != "cpu":
43+
if (
44+
cuml is None
45+
or (is_periodic and num_atoms < 5_000)
46+
or (not is_periodic and num_atoms < 30_000)
47+
):
48+
edge_method = "knn_brute_force"
49+
else:
50+
edge_method = "knn_cuml_rbc"
51+
else:
52+
edge_method = "knn_scipy"
53+
assert edge_method in typing.get_args(EdgeCreationMethod)
54+
return edge_method # type: ignore
55+
56+
3257
def get_atom_embedding(atoms: ase.Atoms, k_hot: bool = False) -> torch.Tensor:
3358
"""Get an atomic embedding."""
3459
atomic_numbers = torch.from_numpy(atoms.numbers).to(torch.long)
@@ -433,6 +458,7 @@ def compute_pbc_radius_graph(
433458
*,
434459
positions: torch.Tensor,
435460
cell: torch.Tensor,
461+
pbc: torch.Tensor,
436462
radius: Union[float, torch.Tensor],
437463
max_number_neighbors: int,
438464
edge_method: Optional[EdgeCreationMethod] = None,
@@ -446,13 +472,21 @@ def compute_pbc_radius_graph(
446472
Args:
447473
positions (torch.Tensor): 3D positions of particles. Shape [num_particles, 3].
448474
cell (torch.Tensor): A 3x3 matrix where the lattice vectors are rows or columns.
475+
pbc (torch.Tensor): A boolean tensor of shape [3] indicating which directions are periodic.
449476
radius (Union[float, torch.tensor]): The radius within which to connect atoms.
450-
max_number_neighbors (int, optional): The maximum number of neighbors for each particle. Defaults to 20.
477+
max_number_neighbors (int): The maximum number of neighbors for each particle.
451478
edge_method (EdgeCreationMethod, optional): The method to use for graph edge construction.
452-
Defaults to None, in which case knn_brute_force is used if we are on GPU (2-6x faster),
453-
otherwise knn_scipy. More details here: https://github.com/orbital-materials/orb/pull/766
454-
n_workers (int, optional): The number of workers to use for KDTree construction. Defaults to 1.
455-
device (Optional[Union[torch.device, str, int]], optional): The device to use for computation.
479+
Defaults to None, in which case edge method is chosen as follows:
480+
* knn_brute_force: If device is not CPU, and cuML is not installed or num_atoms is < 5000 (PBC)
481+
or < 30000 (non-PBC).
482+
* knn_cuml_rbc: If device is not CPU, and cuML is installed, and num_atoms is >= 5000 (PBC) or
483+
>= 30000 (non-PBC).
484+
* knn_scipy: If device is CPU.
485+
On GPU, for num_atoms ≲ 5000 (PBC) or ≲ 30000 (non-PBC), knn_brute_force is faster than knn_cuml_*,
486+
but uses more memory. For num_atoms ≳ 5000 (PBC) or ≳ 30000 (non-PBC), knn_cuml_* is faster and uses
487+
less memory, but requires cuML to be installed. knn_scipy is typically fastest on the CPU.
488+
n_workers (int, optional): The number of workers for KDTree construction in knn_scipy. Defaults to 1.
489+
device (Union[torch.device, str, int], optional): The device to use for computation.
456490
Defaults to None, in which case GPU is used if available.
457491
half_supercell (bool): Whether to use half the supercell for graph construction, and then symmetrize.
458492
This flag does not affect the resulting graph; it is purely an optimization that can double
@@ -474,16 +508,16 @@ def compute_pbc_radius_graph(
474508

475509
natoms = positions.shape[0]
476510
half_supercell = half_supercell and bool(torch.any(cell != 0.0))
511+
is_periodic = bool(torch.any(cell != 0.0).item() and torch.any(pbc).item())
477512

478513
device = get_device(requested_device=device)
479-
if edge_method is None:
480-
edge_method = "knn_brute_force" if device.type != "cpu" else "knn_scipy"
481-
if edge_method == "knn_brute_force":
482-
# if knn brute force, then try to place tensors on the gpu
514+
edge_method = edge_method or get_default_edge_method(device, natoms, is_periodic)
515+
if edge_method == "knn_brute_force" or edge_method.startswith("knn_cuml_"):
516+
# if knn_brute_force or knn_cuml_*, then try to place tensors on the gpu if device is not provided
483517
positions = positions.to(device)
484518
cell = cell.to(device)
485519

486-
if torch.any(cell != 0.0):
520+
if is_periodic:
487521
if half_supercell:
488522
supercell_positions, integer_offsets = construct_half_3x3x3_supercell(
489523
positions=positions, cell=cell
@@ -575,6 +609,8 @@ def compute_supercell_neighbors(
575609
edge_method (EdgeCreationMethod): The method to use for graph edge construction:
576610
- knn_brute_force: Use brute force knn implementation: compute all pairwise distances between
577611
positions and supercell_positions, and subsequently filter edges based on radius and max_num_neighbors.
612+
- knn_cuml_rbc: Use cuML's random-ball algorithm implementation.
613+
- knn_cuml_brute: Use cuML's brute force implementation.
578614
- knn_scipy: Use scipy's KDTree implementation.
579615
n_workers (int, optional): The number of workers to use for KDTree construction. Defaults to 1.
580616
"""
@@ -588,6 +624,33 @@ def compute_supercell_neighbors(
588624
within_radius = distances[:, 1:] < (radius + 1e-6)
589625
num_neighbors_per_sender = within_radius.sum(-1)
590626
supercell_receivers = supercell_receivers[:, 1:][within_radius]
627+
elif edge_method.startswith("knn_cuml_"):
628+
if cuml is None:
629+
raise ImportError(
630+
"cuML is not installed. Please install cuML: https://docs.rapids.ai/install/."
631+
)
632+
assert (
633+
supercell_positions.device.type == "cuda"
634+
and central_cell_positions.device.type == "cuda"
635+
), "cuML KNN is only supported on CUDA devices"
636+
algorithm = edge_method.split("_")[-1]
637+
k = min(max_num_neighbors + 1, len(supercell_positions))
638+
knn = cuml.neighbors.NearestNeighbors(
639+
n_neighbors=k,
640+
algorithm=algorithm,
641+
metric="euclidean",
642+
)
643+
knn.fit(supercell_positions)
644+
distances, supercell_receivers = knn.kneighbors(
645+
central_cell_positions, return_distance=True
646+
)
647+
# Convert from CuPy arrays to PyTorch tensors
648+
distances = torch.as_tensor(distances)
649+
supercell_receivers = torch.as_tensor(supercell_receivers)
650+
# remove self-edges and edges beyond radius
651+
within_radius = distances[:, 1:] < (radius + 1e-6)
652+
num_neighbors_per_sender = within_radius.sum(-1)
653+
supercell_receivers = supercell_receivers[:, 1:][within_radius]
591654
elif edge_method == "knn_scipy":
592655
tree_data = supercell_positions.clone().detach().cpu().numpy()
593656
tree_query = central_cell_positions.clone().detach().cpu().numpy()
@@ -600,6 +663,9 @@ def compute_supercell_neighbors(
600663
workers=n_workers,
601664
p=2,
602665
)
666+
if len(supercell_receivers.shape) == 1:
667+
supercell_receivers = supercell_receivers[None, :]
668+
603669
# Remove the self-edge that will be closest
604670
supercell_receivers = np.array(supercell_receivers)[:, 1:] # type: ignore
605671

@@ -688,6 +754,7 @@ def batch_compute_pbc_radius_graph(
688754
*,
689755
positions: torch.Tensor,
690756
cells: torch.Tensor,
757+
pbc: torch.Tensor,
691758
radius: Union[float, torch.Tensor],
692759
n_node: torch.Tensor,
693760
max_number_neighbors: torch.Tensor,
@@ -704,13 +771,21 @@ def batch_compute_pbc_radius_graph(
704771
Args:
705772
positions (torch.Tensor): 3D positions of a batch of particles. Shape [num_particles, 3].
706773
cells (torch.Tensor): A batch of 3x3 matrices where the lattice vectors are rows.
774+
pbc (torch.Tensor): A batch of boolean tensors of shape [3] indicating which directions are periodic.
707775
radius (Union[float, torch.tensor]): The radius within which to connect atoms.
708776
n_node (torch.Tensor): A vector where each element indicates the number of particles in each element of
709777
the batch. Of size len(batch).
710778
max_number_neighbors (torch.Tensor): The maximum number of neighbors for each particle.
711779
edge_method (EdgeCreationMethod, optional): The method to use for graph edge construction.
712-
Defaults to None, in which case knn_brute_force is used if we are on GPU (2-6x faster),
713-
otherwise knn_scipy. More details here: https://github.com/orbital-materials/orb/pull/766
780+
Defaults to None, in which case edge method is chosen as follows:
781+
* knn_brute_force: If device is not CPU, and cuML is not installed or num_atoms is < 5000 (PBC)
782+
or < 30000 (non-PBC).
783+
* knn_cuml_rbc: If device is not CPU, and cuML is installed, and num_atoms is >= 5000 (PBC) or
784+
>= 30000 (non-PBC).
785+
* knn_scipy: If device is CPU.
786+
On GPU, for num_atoms ≲ 5000 (PBC) or ≲ 30000 (non-PBC), knn_brute_force is faster than knn_cuml_*,
787+
but uses more memory. For num_atoms ≳ 5000 (PBC) or ≳ 30000 (non-PBC), knn_cuml_* is faster and uses
788+
less memory, but requires cuML to be installed. knn_scipy is typically fastest on the CPU.
714789
half_supercell (bool): Whether to use half the supercell for graph construction, and then symmetrize.
715790
This flag does not affect the resulting graph; it is purely an optimization that can double
716791
throughput and half memory for very large cells (e.g. 10k+ atoms). For smaller systems, it can harm
@@ -731,16 +806,17 @@ def batch_compute_pbc_radius_graph(
731806
num_edges = []
732807
all_unit_shifts = []
733808

734-
device = positions.device
735-
for p, pbc, mn in zip(
809+
for p, cell, pbc, mn in zip(
736810
torch.tensor_split(positions, torch.cumsum(n_node, 0)[:-1].cpu()),
737811
cells,
812+
pbc,
738813
max_number_neighbors,
739814
strict=True,
740815
):
741816
edges, vectors, unit_shifts = compute_pbc_radius_graph(
742817
positions=p,
743-
cell=pbc,
818+
cell=cell,
819+
pbc=pbc,
744820
radius=radius,
745821
max_number_neighbors=int(mn),
746822
edge_method=edge_method,

0 commit comments

Comments
 (0)