11"""Featurization utilities for molecular models."""
22
3+ import typing
34from typing import Optional , Tuple , Union , Literal , List
45
56import ase
67import numpy as np
78import torch
89
10+ try :
11+ import cuml
12+ except ImportError :
13+ cuml = None
14+
915from scipy .spatial import KDTree as SciKDTree
1016
1117
@@ -29,6 +35,25 @@ def get_device(
2935 return torch .device (requested_device )
3036
3137
38+ def get_default_edge_method (
39+ device : torch .device , num_atoms : int , is_periodic : bool
40+ ) -> EdgeCreationMethod :
41+ """Get the default edge method for a given device and number of atoms."""
42+ if device .type != "cpu" :
43+ if (
44+ cuml is None
45+ or (is_periodic and num_atoms < 5_000 )
46+ or (not is_periodic and num_atoms < 30_000 )
47+ ):
48+ edge_method = "knn_brute_force"
49+ else :
50+ edge_method = "knn_cuml_rbc"
51+ else :
52+ edge_method = "knn_scipy"
53+ assert edge_method in typing .get_args (EdgeCreationMethod )
54+ return edge_method # type: ignore
55+
56+
3257def get_atom_embedding (atoms : ase .Atoms , k_hot : bool = False ) -> torch .Tensor :
3358 """Get an atomic embedding."""
3459 atomic_numbers = torch .from_numpy (atoms .numbers ).to (torch .long )
@@ -433,6 +458,7 @@ def compute_pbc_radius_graph(
433458 * ,
434459 positions : torch .Tensor ,
435460 cell : torch .Tensor ,
461+ pbc : torch .Tensor ,
436462 radius : Union [float , torch .Tensor ],
437463 max_number_neighbors : int ,
438464 edge_method : Optional [EdgeCreationMethod ] = None ,
@@ -446,13 +472,21 @@ def compute_pbc_radius_graph(
446472 Args:
447473 positions (torch.Tensor): 3D positions of particles. Shape [num_particles, 3].
448474 cell (torch.Tensor): A 3x3 matrix where the lattice vectors are rows or columns.
475+ pbc (torch.Tensor): A boolean tensor of shape [3] indicating which directions are periodic.
449476 radius (Union[float, torch.tensor]): The radius within which to connect atoms.
450- max_number_neighbors (int, optional ): The maximum number of neighbors for each particle. Defaults to 20 .
477+ max_number_neighbors (int): The maximum number of neighbors for each particle.
451478 edge_method (EdgeCreationMethod, optional): The method to use for graph edge construction.
452- Defaults to None, in which case knn_brute_force is used if we are on GPU (2-6x faster),
453- otherwise knn_scipy. More details here: https://github.com/orbital-materials/orb/pull/766
454- n_workers (int, optional): The number of workers to use for KDTree construction. Defaults to 1.
455- device (Optional[Union[torch.device, str, int]], optional): The device to use for computation.
479+ Defaults to None, in which case edge method is chosen as follows:
480+ * knn_brute_force: If device is not CPU, and cuML is not installed or num_atoms is < 5000 (PBC)
481+ or < 30000 (non-PBC).
482+ * knn_cuml_rbc: If device is not CPU, and cuML is installed, and num_atoms is >= 5000 (PBC) or
483+ >= 30000 (non-PBC).
484+ * knn_scipy: If device is CPU.
485+ On GPU, for num_atoms ≲ 5000 (PBC) or ≲ 30000 (non-PBC), knn_brute_force is faster than knn_cuml_*,
486+ but uses more memory. For num_atoms ≳ 5000 (PBC) or ≳ 30000 (non-PBC), knn_cuml_* is faster and uses
487+ less memory, but requires cuML to be installed. knn_scipy is typically fastest on the CPU.
488+ n_workers (int, optional): The number of workers for KDTree construction in knn_scipy. Defaults to 1.
489+ device (Union[torch.device, str, int], optional): The device to use for computation.
456490 Defaults to None, in which case GPU is used if available.
457491 half_supercell (bool): Whether to use half the supercell for graph construction, and then symmetrize.
458492 This flag does not affect the resulting graph; it is purely an optimization that can double
@@ -474,16 +508,16 @@ def compute_pbc_radius_graph(
474508
475509 natoms = positions .shape [0 ]
476510 half_supercell = half_supercell and bool (torch .any (cell != 0.0 ))
511+ is_periodic = bool (torch .any (cell != 0.0 ).item () and torch .any (pbc ).item ())
477512
478513 device = get_device (requested_device = device )
479- if edge_method is None :
480- edge_method = "knn_brute_force" if device .type != "cpu" else "knn_scipy"
481- if edge_method == "knn_brute_force" :
482- # if knn brute force, then try to place tensors on the gpu
514+ edge_method = edge_method or get_default_edge_method (device , natoms , is_periodic )
515+ if edge_method == "knn_brute_force" or edge_method .startswith ("knn_cuml_" ):
516+ # if knn_brute_force or knn_cuml_*, then try to place tensors on the gpu if device is not provided
483517 positions = positions .to (device )
484518 cell = cell .to (device )
485519
486- if torch . any ( cell != 0.0 ) :
520+ if is_periodic :
487521 if half_supercell :
488522 supercell_positions , integer_offsets = construct_half_3x3x3_supercell (
489523 positions = positions , cell = cell
@@ -575,6 +609,8 @@ def compute_supercell_neighbors(
575609 edge_method (EdgeCreationMethod): The method to use for graph edge construction:
576610 - knn_brute_force: Use brute force knn implementation: compute all pairwise distances between
577611 positions and supercell_positions, and subsequently filter edges based on radius and max_num_neighbors.
612+ - knn_cuml_rbc: Use cuML's random-ball algorithm implementation.
613+ - knn_cuml_brute: Use cuML's brute force implementation.
578614 - knn_scipy: Use scipy's KDTree implementation.
579615 n_workers (int, optional): The number of workers to use for KDTree construction. Defaults to 1.
580616 """
@@ -588,6 +624,33 @@ def compute_supercell_neighbors(
588624 within_radius = distances [:, 1 :] < (radius + 1e-6 )
589625 num_neighbors_per_sender = within_radius .sum (- 1 )
590626 supercell_receivers = supercell_receivers [:, 1 :][within_radius ]
627+ elif edge_method .startswith ("knn_cuml_" ):
628+ if cuml is None :
629+ raise ImportError (
630+ "cuML is not installed. Please install cuML: https://docs.rapids.ai/install/."
631+ )
632+ assert (
633+ supercell_positions .device .type == "cuda"
634+ and central_cell_positions .device .type == "cuda"
635+ ), "cuML KNN is only supported on CUDA devices"
636+ algorithm = edge_method .split ("_" )[- 1 ]
637+ k = min (max_num_neighbors + 1 , len (supercell_positions ))
638+ knn = cuml .neighbors .NearestNeighbors (
639+ n_neighbors = k ,
640+ algorithm = algorithm ,
641+ metric = "euclidean" ,
642+ )
643+ knn .fit (supercell_positions )
644+ distances , supercell_receivers = knn .kneighbors (
645+ central_cell_positions , return_distance = True
646+ )
647+ # Convert from CuPy arrays to PyTorch tensors
648+ distances = torch .as_tensor (distances )
649+ supercell_receivers = torch .as_tensor (supercell_receivers )
650+ # remove self-edges and edges beyond radius
651+ within_radius = distances [:, 1 :] < (radius + 1e-6 )
652+ num_neighbors_per_sender = within_radius .sum (- 1 )
653+ supercell_receivers = supercell_receivers [:, 1 :][within_radius ]
591654 elif edge_method == "knn_scipy" :
592655 tree_data = supercell_positions .clone ().detach ().cpu ().numpy ()
593656 tree_query = central_cell_positions .clone ().detach ().cpu ().numpy ()
@@ -600,6 +663,9 @@ def compute_supercell_neighbors(
600663 workers = n_workers ,
601664 p = 2 ,
602665 )
666+ if len (supercell_receivers .shape ) == 1 :
667+ supercell_receivers = supercell_receivers [None , :]
668+
603669 # Remove the self-edge that will be closest
604670 supercell_receivers = np .array (supercell_receivers )[:, 1 :] # type: ignore
605671
@@ -688,6 +754,7 @@ def batch_compute_pbc_radius_graph(
688754 * ,
689755 positions : torch .Tensor ,
690756 cells : torch .Tensor ,
757+ pbc : torch .Tensor ,
691758 radius : Union [float , torch .Tensor ],
692759 n_node : torch .Tensor ,
693760 max_number_neighbors : torch .Tensor ,
@@ -704,13 +771,21 @@ def batch_compute_pbc_radius_graph(
704771 Args:
705772 positions (torch.Tensor): 3D positions of a batch of particles. Shape [num_particles, 3].
706773 cells (torch.Tensor): A batch of 3x3 matrices where the lattice vectors are rows.
774+ pbc (torch.Tensor): A batch of boolean tensors of shape [3] indicating which directions are periodic.
707775 radius (Union[float, torch.tensor]): The radius within which to connect atoms.
708776 n_node (torch.Tensor): A vector where each element indicates the number of particles in each element of
709777 the batch. Of size len(batch).
710778 max_number_neighbors (torch.Tensor): The maximum number of neighbors for each particle.
711779 edge_method (EdgeCreationMethod, optional): The method to use for graph edge construction.
712- Defaults to None, in which case knn_brute_force is used if we are on GPU (2-6x faster),
713- otherwise knn_scipy. More details here: https://github.com/orbital-materials/orb/pull/766
780+ Defaults to None, in which case edge method is chosen as follows:
781+ * knn_brute_force: If device is not CPU, and cuML is not installed or num_atoms is < 5000 (PBC)
782+ or < 30000 (non-PBC).
783+ * knn_cuml_rbc: If device is not CPU, and cuML is installed, and num_atoms is >= 5000 (PBC) or
784+ >= 30000 (non-PBC).
785+ * knn_scipy: If device is CPU.
786+ On GPU, for num_atoms ≲ 5000 (PBC) or ≲ 30000 (non-PBC), knn_brute_force is faster than knn_cuml_*,
787+ but uses more memory. For num_atoms ≳ 5000 (PBC) or ≳ 30000 (non-PBC), knn_cuml_* is faster and uses
788+ less memory, but requires cuML to be installed. knn_scipy is typically fastest on the CPU.
714789 half_supercell (bool): Whether to use half the supercell for graph construction, and then symmetrize.
715790 This flag does not affect the resulting graph; it is purely an optimization that can double
716791 throughput and half memory for very large cells (e.g. 10k+ atoms). For smaller systems, it can harm
@@ -731,16 +806,17 @@ def batch_compute_pbc_radius_graph(
731806 num_edges = []
732807 all_unit_shifts = []
733808
734- device = positions .device
735- for p , pbc , mn in zip (
809+ for p , cell , pbc , mn in zip (
736810 torch .tensor_split (positions , torch .cumsum (n_node , 0 )[:- 1 ].cpu ()),
737811 cells ,
812+ pbc ,
738813 max_number_neighbors ,
739814 strict = True ,
740815 ):
741816 edges , vectors , unit_shifts = compute_pbc_radius_graph (
742817 positions = p ,
743- cell = pbc ,
818+ cell = cell ,
819+ pbc = pbc ,
744820 radius = radius ,
745821 max_number_neighbors = int (mn ),
746822 edge_method = edge_method ,
0 commit comments