Skip to content

Commit 340cf9c

Browse files
authored
Merge pull request #23 from mcgalcode/improvements
Performance improvements
2 parents 1ca08f0 + e5b01fc commit 340cf9c

18 files changed

+564
-77
lines changed

.prospector.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ max-line-length: 120
22
test-warnings: false
33
doc-warnings: false
44
strictness: medium
5+
with: []
6+
uses: []
57
ignore-paths:
68
- docs
79
- tests
@@ -21,6 +23,7 @@ pycodestyle:
2123

2224
pylint:
2325
disable:
26+
- django-not-available
2427
- unsubscriptable-object
2528
- invalid-name
2629
- arguments-differ # to account for jobflow

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ enabled = true
113113

114114
[dependency-groups]
115115
dev = [
116+
"black>=24.8.0",
117+
"prospector>=1.10.3",
116118
"pytest>=7.1.3",
117119
"pytest-cov>=4.0.0",
118-
]
120+
]

src/pylattica/core/basic_controller.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,18 @@ class has a single responsibility, which is to implement the update
1515
SimulationState will be passed to this method, along with the ID of
1616
the site at which the update rule should be applied. It is up to the
1717
user to decide what updates should be produced using this information.
18+
19+
Attributes
20+
----------
21+
max_history : int, optional
22+
Maximum number of step diffs to keep in memory during simulation.
23+
Set this to limit memory usage for long simulations. When exceeded,
24+
older steps are checkpointed and dropped. Default is None (unlimited).
1825
"""
1926

27+
# Override this in subclasses to limit memory usage
28+
max_history: int = None
29+
2030
@abstractmethod
2131
def get_state_update(self, site_id: int, prev_state: SimulationState):
2232
pass # pragma: no cover
@@ -25,7 +35,10 @@ def pre_run(self, initial_state: SimulationState) -> None:
2535
pass
2636

2737
def get_random_site(self, state: SimulationState):
28-
return random.randint(0, len(state.site_ids()) - 1)
38+
# Use state.size (O(1)) instead of len(state.site_ids()) which is O(n)
39+
return random.randint(0, state.size - 1)
2940

3041
def instantiate_result(self, starting_state: SimulationState):
31-
return SimulationResult(starting_state=starting_state)
42+
return SimulationResult(
43+
starting_state=starting_state, max_history=self.max_history
44+
)

src/pylattica/core/neighborhood_builders.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -139,16 +139,14 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood
139139
else:
140140
sites_to_process = struct.sites(site_class=site_class)
141141

142-
n_sites = len(all_sites)
143-
144142
# Extract locations and IDs as arrays for vectorized operations
145143
locations = np.array([s[LOCATION] for s in all_sites])
146144
site_ids = np.array([s[SITE_ID] for s in all_sites])
147145

148146
# Convert to fractional coordinates for periodic KD-tree
149-
frac_coords = np.array([
150-
struct.lattice.get_fractional_coords(loc) for loc in locations
151-
])
147+
frac_coords = np.array(
148+
[struct.lattice.get_fractional_coords(loc) for loc in locations]
149+
)
152150

153151
# Compute the maximum fractional radius that could correspond to
154152
# the Cartesian cutoff. For non-orthogonal lattices, we need to use
@@ -163,9 +161,7 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood
163161
dim = struct.lattice.dim
164162

165163
# Build boxsize array: 1.0 for periodic dimensions, large value for non-periodic
166-
boxsize = np.array([
167-
1.0 if periodic[i] else 1e10 for i in range(dim)
168-
])
164+
boxsize = np.array([1.0 if periodic[i] else 1e10 for i in range(dim)])
169165

170166
# Wrap fractional coordinates to [0, 1) for periodic dimensions
171167
frac_coords_wrapped = frac_coords.copy()
@@ -176,9 +172,6 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood
176172
# Build KD-tree with periodic boundary conditions
177173
tree = cKDTree(frac_coords_wrapped, boxsize=boxsize)
178174

179-
# Create index mapping from site_id to array index
180-
id_to_idx = {sid: idx for idx, sid in enumerate(site_ids)}
181-
182175
# Process each site
183176
sites_to_process_ids = set(s[SITE_ID] for s in sites_to_process)
184177

@@ -296,9 +289,9 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood
296289
site_ids = np.array([s[SITE_ID] for s in all_sites])
297290

298291
# Convert to fractional coordinates for periodic KD-tree
299-
frac_coords = np.array([
300-
struct.lattice.get_fractional_coords(loc) for loc in locations
301-
])
292+
frac_coords = np.array(
293+
[struct.lattice.get_fractional_coords(loc) for loc in locations]
294+
)
302295

303296
# Compute the maximum fractional radius for the outer cutoff.
304297
# Use the maximum stretch factor of the inverse matrix for non-orthogonal lattices.
@@ -311,9 +304,7 @@ def get(self, struct: PeriodicStructure, site_class: str = None) -> Neighborhood
311304
dim = struct.lattice.dim
312305

313306
# Build boxsize array
314-
boxsize = np.array([
315-
1.0 if periodic[i] else 1e10 for i in range(dim)
316-
])
307+
boxsize = np.array([1.0 if periodic[i] else 1e10 for i in range(dim)])
317308

318309
# Wrap fractional coordinates to [0, 1) for periodic dimensions
319310
frac_coords_wrapped = frac_coords.copy()

src/pylattica/core/runner/asynchronous_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class AsynchronousRunner(Runner):
2727
that this mode should be used with the is_async initialization parameter.
2828
"""
2929

30-
def _run(
30+
def _run( # pylint: disable=too-many-positional-arguments
3131
self,
3232
_: SimulationState,
3333
result: SimulationResult,

src/pylattica/core/runner/synchronous_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, parallel: bool = False, workers: int = None) -> None:
3838
self.parallel = parallel
3939
self.workers = workers
4040

41-
def _run(
41+
def _run( # pylint: disable=too-many-positional-arguments
4242
self,
4343
initial_state: SimulationState,
4444
result: SimulationResult,

0 commit comments

Comments
 (0)