Skip to content

Commit 0c68212

Browse files
committed
vastly faster VN neighborhood building
1 parent 1b6b332 commit 0c68212

File tree

1 file changed

+118
-9
lines changed

1 file changed

+118
-9
lines changed

src/pylattica/structures/square_grid/neighborhoods.py

Lines changed: 118 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import numpy as np
2+
import rustworkx as rx
23
from ...core.coordinate_utils import get_points_in_cube
34
from ...core.neighborhood_builders import (
45
StochasticNeighborhoodBuilder,
56
MotifNeighborhoodBuilder,
67
DistanceNeighborhoodBuilder,
8+
NeighborhoodBuilder,
79
)
10+
from ...core.neighborhoods import Neighborhood
11+
from ...core.periodic_structure import PeriodicStructure
812

913

1014
class VonNeumannNbHood2DBuilder(MotifNeighborhoodBuilder):
@@ -28,25 +32,130 @@ def __init__(self, size=1):
2832
super().__init__(filtered_points)
2933

3034

31-
class VonNeumannNbHood3DBuilder(MotifNeighborhoodBuilder):
32-
"""A helper class for generating von Neumann type neighborhoods in square 3D structures."""
35+
class VonNeumannNbHood3DBuilder(NeighborhoodBuilder):
36+
"""Optimized Von Neumann neighborhood builder for simple cubic 3D grids.
37+
38+
Uses direct index math instead of coordinate lookups, providing
39+
much faster performance for large grids.
40+
41+
For a cubic grid of size n³:
42+
- Site i is at position (x, y, z) where x = i % n, y = (i // n) % n, z = i // n²
43+
- Neighbor at offset (dx, dy, dz) has ID: ((x+dx) % n) + ((y+dy) % n) * n + ((z+dz) % n) * n²
44+
"""
3345

3446
def __init__(self, size: int):
35-
"""Constructs the VonNeumannNbHood3D Builder
47+
"""Constructs the VonNeumannNbHood3DBuilder.
3648
3749
Parameters
3850
----------
3951
size : int
40-
The size of the neighborhood.
52+
The size of the neighborhood (Manhattan distance).
4153
"""
54+
# Generate Von Neumann neighborhood offsets (excluding origin)
4255
points = get_points_in_cube(-size, size + 1, 3)
56+
self._offsets = [
57+
tuple(point) for point in points
58+
if sum(np.abs(p) for p in point) <= size and any(p != 0 for p in point)
59+
]
60+
# Precompute distances for edge weights
61+
self._distances = {
62+
offset: np.sqrt(sum(p**2 for p in offset))
63+
for offset in self._offsets
64+
}
65+
# Cache for grid size (computed once per structure)
66+
self._cached_n = None
67+
self._cached_n_sites = None
4368

44-
filtered_points = []
45-
for point in points:
46-
if sum(np.abs(p) for p in point) <= size:
47-
filtered_points.append(point)
69+
def _get_grid_size(self, struct: PeriodicStructure) -> int:
70+
"""Infer grid size n from structure (cached)."""
71+
n_sites = len(struct.site_ids)
72+
if n_sites != self._cached_n_sites:
73+
n = int(round(n_sites ** (1/3)))
74+
if n ** 3 != n_sites:
75+
raise ValueError(f"Structure has {n_sites} sites, not a perfect cube.")
76+
self._cached_n = n
77+
self._cached_n_sites = n_sites
78+
return self._cached_n
4879

49-
super().__init__(filtered_points)
80+
def get_neighbors(self, curr_site: dict, struct: PeriodicStructure) -> list:
81+
"""Get neighbors of a site using fast index math.
82+
83+
Parameters
84+
----------
85+
curr_site : dict
86+
Site dictionary with 'id' key
87+
struct : PeriodicStructure
88+
The structure (used to infer grid size)
89+
90+
Returns
91+
-------
92+
list
93+
List of (neighbor_id, distance) tuples
94+
"""
95+
from ...core.constants import SITE_ID
96+
97+
n = self._get_grid_size(struct)
98+
site_id = curr_site[SITE_ID]
99+
100+
# Convert site ID to (x, y, z) coordinates
101+
x = site_id % n
102+
y = (site_id // n) % n
103+
z = site_id // (n * n)
104+
105+
neighbors = []
106+
for dx, dy, dz in self._offsets:
107+
# Compute neighbor coordinates with periodic boundary conditions
108+
nx = (x + dx) % n
109+
ny = (y + dy) % n
110+
nz = (z + dz) % n
111+
112+
# Convert back to site ID
113+
neighbor_id = nx + ny * n + nz * (n * n)
114+
neighbors.append((neighbor_id, self._distances[(dx, dy, dz)]))
115+
116+
return neighbors
117+
118+
def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood:
119+
"""Build neighborhood graph using vectorized index math.
120+
121+
This override provides much faster performance than the base class
122+
by computing all edges in bulk using numpy operations.
123+
"""
124+
n_sites = len(struct.site_ids)
125+
n = self._get_grid_size(struct)
126+
127+
graph = rx.PyDiGraph()
128+
129+
# Add all nodes at once
130+
graph.add_nodes_from(range(n_sites))
131+
132+
# Vectorized computation: create coordinate arrays for all sites
133+
site_ids = np.arange(n_sites, dtype=np.int64)
134+
x = site_ids % n
135+
y = (site_ids // n) % n
136+
z = site_ids // (n * n)
137+
138+
# Collect all edges across all offsets, then add in one batch
139+
all_edges = []
140+
for dx, dy, dz in self._offsets:
141+
# Compute neighbor coordinates with periodic boundary conditions
142+
nx = (x + dx) % n
143+
ny = (y + dy) % n
144+
nz = (z + dz) % n
145+
146+
# Convert to neighbor site IDs
147+
neighbor_ids = nx + ny * n + nz * (n * n)
148+
weight = self._distances[(dx, dy, dz)]
149+
150+
# Stack source, dest, weight as columns and extend
151+
# Use numpy operations to avoid Python loop overhead
152+
edge_data = np.column_stack([site_ids, neighbor_ids])
153+
all_edges.extend((int(s), int(d), weight) for s, d in edge_data)
154+
155+
# Add all edges in one batch
156+
graph.extend_from_weighted_edge_list(all_edges)
157+
158+
return Neighborhood(graph)
50159

51160

52161
class MooreNbHoodBuilder(MotifNeighborhoodBuilder):

0 commit comments

Comments
 (0)