11import numpy as np
2+ import rustworkx as rx
23from ...core .coordinate_utils import get_points_in_cube
34from ...core .neighborhood_builders import (
45 StochasticNeighborhoodBuilder ,
56 MotifNeighborhoodBuilder ,
67 DistanceNeighborhoodBuilder ,
8+ NeighborhoodBuilder ,
79)
10+ from ...core .neighborhoods import Neighborhood
11+ from ...core .periodic_structure import PeriodicStructure
812
913
1014class VonNeumannNbHood2DBuilder (MotifNeighborhoodBuilder ):
@@ -28,25 +32,130 @@ def __init__(self, size=1):
2832 super ().__init__ (filtered_points )
2933
3034
31- class VonNeumannNbHood3DBuilder (MotifNeighborhoodBuilder ):
32- """A helper class for generating von Neumann type neighborhoods in square 3D structures."""
35+ class VonNeumannNbHood3DBuilder (NeighborhoodBuilder ):
36+ """Optimized Von Neumann neighborhood builder for simple cubic 3D grids.
37+
38+ Uses direct index math instead of coordinate lookups, providing
39+ much faster performance for large grids.
40+
41+ For a cubic grid of size n³:
42+ - Site i is at position (x, y, z) where x = i % n, y = (i // n) % n, z = i // n²
43+ - Neighbor at offset (dx, dy, dz) has ID: ((x+dx) % n) + ((y+dy) % n) * n + ((z+dz) % n) * n²
44+ """
3345
3446 def __init__ (self , size : int ):
35- """Constructs the VonNeumannNbHood3D Builder
47+ """Constructs the VonNeumannNbHood3DBuilder.
3648
3749 Parameters
3850 ----------
3951 size : int
40- The size of the neighborhood.
52+ The size of the neighborhood (Manhattan distance) .
4153 """
54+ # Generate Von Neumann neighborhood offsets (excluding origin)
4255 points = get_points_in_cube (- size , size + 1 , 3 )
56+ self ._offsets = [
57+ tuple (point ) for point in points
58+ if sum (np .abs (p ) for p in point ) <= size and any (p != 0 for p in point )
59+ ]
60+ # Precompute distances for edge weights
61+ self ._distances = {
62+ offset : np .sqrt (sum (p ** 2 for p in offset ))
63+ for offset in self ._offsets
64+ }
65+ # Cache for grid size (computed once per structure)
66+ self ._cached_n = None
67+ self ._cached_n_sites = None
4368
44- filtered_points = []
45- for point in points :
46- if sum (np .abs (p ) for p in point ) <= size :
47- filtered_points .append (point )
69+ def _get_grid_size (self , struct : PeriodicStructure ) -> int :
70+ """Infer grid size n from structure (cached)."""
71+ n_sites = len (struct .site_ids )
72+ if n_sites != self ._cached_n_sites :
73+ n = int (round (n_sites ** (1 / 3 )))
74+ if n ** 3 != n_sites :
75+ raise ValueError (f"Structure has { n_sites } sites, not a perfect cube." )
76+ self ._cached_n = n
77+ self ._cached_n_sites = n_sites
78+ return self ._cached_n
4879
49- super ().__init__ (filtered_points )
80+ def get_neighbors (self , curr_site : dict , struct : PeriodicStructure ) -> list :
81+ """Get neighbors of a site using fast index math.
82+
83+ Parameters
84+ ----------
85+ curr_site : dict
86+ Site dictionary with 'id' key
87+ struct : PeriodicStructure
88+ The structure (used to infer grid size)
89+
90+ Returns
91+ -------
92+ list
93+ List of (neighbor_id, distance) tuples
94+ """
95+ from ...core .constants import SITE_ID
96+
97+ n = self ._get_grid_size (struct )
98+ site_id = curr_site [SITE_ID ]
99+
100+ # Convert site ID to (x, y, z) coordinates
101+ x = site_id % n
102+ y = (site_id // n ) % n
103+ z = site_id // (n * n )
104+
105+ neighbors = []
106+ for dx , dy , dz in self ._offsets :
107+ # Compute neighbor coordinates with periodic boundary conditions
108+ nx = (x + dx ) % n
109+ ny = (y + dy ) % n
110+ nz = (z + dz ) % n
111+
112+ # Convert back to site ID
113+ neighbor_id = nx + ny * n + nz * (n * n )
114+ neighbors .append ((neighbor_id , self ._distances [(dx , dy , dz )]))
115+
116+ return neighbors
117+
118+ def get (self , struct : PeriodicStructure , site_class : str = None ) -> Neighborhood :
119+ """Build neighborhood graph using vectorized index math.
120+
121+ This override provides much faster performance than the base class
122+ by computing all edges in bulk using numpy operations.
123+ """
124+ n_sites = len (struct .site_ids )
125+ n = self ._get_grid_size (struct )
126+
127+ graph = rx .PyDiGraph ()
128+
129+ # Add all nodes at once
130+ graph .add_nodes_from (range (n_sites ))
131+
132+ # Vectorized computation: create coordinate arrays for all sites
133+ site_ids = np .arange (n_sites , dtype = np .int64 )
134+ x = site_ids % n
135+ y = (site_ids // n ) % n
136+ z = site_ids // (n * n )
137+
138+ # Collect all edges across all offsets, then add in one batch
139+ all_edges = []
140+ for dx , dy , dz in self ._offsets :
141+ # Compute neighbor coordinates with periodic boundary conditions
142+ nx = (x + dx ) % n
143+ ny = (y + dy ) % n
144+ nz = (z + dz ) % n
145+
146+ # Convert to neighbor site IDs
147+ neighbor_ids = nx + ny * n + nz * (n * n )
148+ weight = self ._distances [(dx , dy , dz )]
149+
150+ # Stack source, dest, weight as columns and extend
151+ # Use numpy operations to avoid Python loop overhead
152+ edge_data = np .column_stack ([site_ids , neighbor_ids ])
153+ all_edges .extend ((int (s ), int (d ), weight ) for s , d in edge_data )
154+
155+ # Add all edges in one batch
156+ graph .extend_from_weighted_edge_list (all_edges )
157+
158+ return Neighborhood (graph )
50159
51160
52161class MooreNbHoodBuilder (MotifNeighborhoodBuilder ):
0 commit comments