diff --git a/src/lorem/calculator.py b/src/lorem/calculator.py index eb52ea4..25b478f 100644 --- a/src/lorem/calculator.py +++ b/src/lorem/calculator.py @@ -7,9 +7,10 @@ from ase.calculators.calculator import ( BaseCalculator, PropertyNotImplementedError, - compare_atoms, ) +from lorem.neighborlist import NeighborListCache + class Calculator(BaseCalculator): name = "lorem" @@ -35,12 +36,16 @@ def __init__( stress=False, add_offset=True, double_precision=False, + skin=0.25, ): self.params = params self.cutoff = cutoff + self.skin = skin self.add_offset = add_offset self.double_precision = double_precision + self._nl_cache = NeighborListCache(skin=skin) + if not stress: self.implemented_properties = ["born_effective_charges", "energy", "forces"] @@ -91,20 +96,73 @@ def from_checkpoint( return cls(model.predict, species_to_weight, params, model.cutoff, **kwargs) def update(self, atoms): - changes = compare_atoms(self.atoms, atoms) - - if len(changes) > 0: + if self._nl_cache.needs_update(atoms): + # Structural change or combined displacement beyond skin self.results = {} self.atoms = atoms.copy() self.setup(atoms) + elif self.atoms is None or not self._geometry_unchanged(atoms): + # Positions and/or cell changed but within skin budget + self.results = {} + self.atoms = atoms.copy() + self._update_geometry(atoms) + + def _geometry_unchanged(self, atoms): + return np.array_equal( + atoms.get_positions(), self.atoms.get_positions() + ) and np.array_equal(atoms.get_cell()[:], self.atoms.get_cell()[:]) def setup(self, atoms): from lorem.batching import to_batch, to_sample - sample = to_sample(atoms, self.cutoff, energy=False, forces=False, stress=False) + nl_cutoff = self.cutoff + self.skin + + # Derive Ewald parameters from physical cutoff so the long-range + # decomposition is unchanged when using the extended cutoff. + lr_wavelength = self.cutoff / 8.0 + smearing = lr_wavelength * 2.0 + + sample = to_sample( + atoms, + nl_cutoff, + lr_wavelength=lr_wavelength, + smearing=smearing, + energy=False, + forces=False, + stress=False, + ) batch = to_batch([sample], []) self.batch = jax.tree.map(lambda x: jnp.array(x), batch) + max_cell_shift = int(np.abs(np.array(self.batch.sr.cell_shifts)).max()) + self._nl_cache.save_reference(atoms, max_cell_shift=max_cell_shift) + + def _update_geometry(self, atoms): + """Update positions and cell in cached batch without rebuilding + the neighbor list. The model recomputes R_ij from the updated + sr.positions and sr.cell, and the Ewald calculator recomputes + k-vectors from sr.cell (pbc.k_grid stores only integer frequency + indices). So forces, energy, and stress remain correct.""" + sr = self.batch.sr + n_atoms = len(atoms) + + positions = np.zeros(np.array(sr.positions).shape, dtype=np.float32) + positions[:n_atoms] = atoms.get_positions() + + cell = np.array(sr.cell) + new_cell = atoms.get_cell()[:].astype(np.float32) + if atoms.get_pbc().sum() == 2: + from jaxpme.batched_mixed.batching import shrink_2d_cell + + new_cell = shrink_2d_cell(new_cell, atoms.get_pbc(), positions[:n_atoms]) + cell[0] = new_cell + + new_sr = sr._replace( + positions=jnp.array(positions), + cell=jnp.array(cell), + ) + self.batch = self.batch._replace(sr=new_sr) + def calculate( self, atoms=None, diff --git a/src/lorem/ipi.py b/src/lorem/ipi.py index 0e28d0a..1b55398 100644 --- a/src/lorem/ipi.py +++ b/src/lorem/ipi.py @@ -13,15 +13,18 @@ class LOREM_driver(ASEDriver): - def __init__(self, template, model_path, *args, **kwargs): + def __init__(self, template, model_path, *args, skin=0.25, **kwargs): self.model_path = model_path + self.skin = skin super().__init__(template, *args, **kwargs) self.capabilities.append("born_effective_charges") def check_parameters(self): super().check_parameters() has_stress = "stress" in self.capabilities - self.ase_calculator = Calculator.from_checkpoint(self.model_path, stress=has_stress) + self.ase_calculator = Calculator.from_checkpoint( + self.model_path, stress=has_stress, skin=self.skin + ) def compute_structure(self, cell, pos): pot_ipi, force_ipi, vir_ipi, extras = super().compute_structure(cell, pos) diff --git a/src/lorem/neighborlist.py b/src/lorem/neighborlist.py new file mode 100644 index 0000000..53537cc --- /dev/null +++ b/src/lorem/neighborlist.py @@ -0,0 +1,111 @@ +"""Verlet-style neighbor list cache. + +Builds neighbor lists with cutoff + skin and reuses them as long as the +combined position displacement and cell deformation stays within the skin +budget. This avoids expensive neighbor searches on every MD step while +guaranteeing that all pairs within the physical cutoff are present in the +cached list. + +For a pair (i, j) with periodic image shift S, the change in pairwise +distance from the reference is bounded by: + + |dR_ij| <= |dR_i| + |dR_j| + |S . d_cell| + <= 2 * d_max + max_shift * sum(|d_cell_A|) + +The neighbor list remains valid as long as this is < skin. +""" + +import numpy as np + + +class NeighborListCache: + """Cache for neighbor lists with skin-based recomputation. + + The Verlet criterion ensures that a neighbor list built with + cutoff + skin contains all pairs within cutoff even after atomic + displacements and cell deformations, as long as the combined + change stays within the skin budget. + + Parameters + ---------- + skin : float + Skin distance in Angstrom. Default 0.25. + """ + + def __init__(self, skin=0.25): + self.skin = skin + self._reference_positions = None + self._reference_cell = None + self._reference_pbc = None + self._reference_numbers = None + self._max_cell_shift = None + + def needs_update(self, atoms): + """Check if neighbor list needs recomputation. + + Returns True on first call, on any structural change (pbc, + natoms, atomic numbers), or when the combined position + displacement + cell deformation exceeds the skin budget. + + When max_cell_shift is not set (direct use without calculator), + falls back to exact cell comparison for backward compatibility. + """ + if self._reference_positions is None: + return True + + if len(atoms) != len(self._reference_positions): + return True + if (atoms.get_atomic_numbers() != self._reference_numbers).any(): + return True + if (atoms.get_pbc() != self._reference_pbc).any(): + return True + + # Cell handling depends on whether we have cell shift info + if self._max_cell_shift is None: + # No shift info — exact cell comparison (conservative) + if (atoms.get_cell()[:] != self._reference_cell).any(): + return True + + # Position displacement: max over atoms of |dR| + displacements = atoms.get_positions() - self._reference_positions + max_disp = np.sqrt((displacements**2).sum(axis=1).max()) + + # Cell deformation contribution + if self._max_cell_shift is not None and self._max_cell_shift > 0: + # |S . d_cell| <= max_shift * sum_A(|d_cell_A|) + # where |d_cell_A| is the norm of the change in cell vector A + cell_change = atoms.get_cell()[:] - self._reference_cell + cell_vector_norms = np.linalg.norm(cell_change, axis=1) + max_cell_contrib = self._max_cell_shift * cell_vector_norms.sum() + else: + max_cell_contrib = 0.0 + + # Combined criterion: + # max |dR_ij| <= 2*d_max + max_cell_contrib < skin + return bool(2 * max_disp + max_cell_contrib > self.skin) + + def save_reference(self, atoms, max_cell_shift=None): + """Store reference state after neighbor list rebuild. + + Parameters + ---------- + atoms : ase.Atoms + Reference atomic configuration. + max_cell_shift : int or None + Maximum absolute value of any cell shift component in the + neighbor list. Enables the combined position + cell Verlet + criterion. When None, falls back to exact cell comparison. + """ + self._reference_positions = atoms.get_positions().copy() + self._reference_cell = np.array(atoms.get_cell()[:]).copy() + self._reference_pbc = atoms.get_pbc().copy() + self._reference_numbers = atoms.get_atomic_numbers().copy() + self._max_cell_shift = max_cell_shift + + def reset(self): + """Clear the cache.""" + self._reference_positions = None + self._reference_cell = None + self._reference_pbc = None + self._reference_numbers = None + self._max_cell_shift = None diff --git a/tests/test_calculator.py b/tests/test_calculator.py index ca8d31a..d1cfbde 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -47,6 +47,246 @@ def test_calculator_bec_get_property(): assert bec.shape == (len(atoms), 3, 3) +def test_calculator_skin_default(): + """Calculator defaults to skin=0.25.""" + model = Lorem(cutoff=5.0, num_features=8, num_spherical_features=2, num_radial=4) + calc = Calculator.from_model(model) + assert calc.skin == 0.25 + + +def test_calculator_skin_configurable(): + """Skin parameter is passed through from_model.""" + model = Lorem(cutoff=5.0, num_features=8, num_spherical_features=2, num_radial=4) + calc = Calculator.from_model(model, skin=0.5) + assert calc.skin == 0.5 + + +def test_calculator_skin_results_match_no_skin(): + """Calculator with skin gives same results as skin=0 on static structure.""" + model = Lorem(cutoff=5.0, num_features=8, num_spherical_features=2, num_radial=4) + atoms = bulk("Ar") * [2, 2, 2] + + calc_skin = Calculator.from_model(model, skin=0.25) + calc_skin.calculate(atoms) + + calc_no_skin = Calculator.from_model(model, skin=0.0) + calc_no_skin.calculate(atoms) + + np.testing.assert_allclose( + calc_skin.results["energy"], + calc_no_skin.results["energy"], + atol=1e-5, + ) + np.testing.assert_allclose( + calc_skin.results["forces"], + calc_no_skin.results["forces"], + atol=1e-5, + ) + + +def test_calculator_skin_reuses_neighborlist(): + """Small displacement reuses cached neighbor list, results stay correct.""" + model = Lorem(cutoff=5.0, num_features=8, num_spherical_features=2, num_radial=4) + atoms = bulk("Ar") * [2, 2, 2] + + calc = Calculator.from_model(model, skin=0.5) + calc.calculate(atoms) + + # Small displacement within skin/2 + displaced = atoms.copy() + pos = displaced.get_positions() + pos[0] += 0.05 # 0.087 Å < 0.25 Å = 0.5*skin + displaced.set_positions(pos) + + # Should not trigger full rebuild (check via cache state) + assert calc._nl_cache.needs_update(displaced) is False + + calc.calculate(displaced) + assert "energy" in calc.results + assert "forces" in calc.results + + # Compare with fresh calculation on displaced structure + calc_fresh = Calculator.from_model(model, skin=0.5) + calc_fresh.calculate(displaced) + + np.testing.assert_allclose( + calc.results["energy"], + calc_fresh.results["energy"], + atol=1e-5, + ) + np.testing.assert_allclose( + calc.results["forces"], + calc_fresh.results["forces"], + atol=1e-5, + ) + + +def test_calculator_skin_forces_after_displacement(): + """Forces are correct after position-only update (no full rebuild).""" + model = Lorem(cutoff=5.0, num_features=8, num_spherical_features=2, num_radial=4) + atoms = bulk("Ar") * [2, 2, 2] + skin = 0.5 + + calc = Calculator.from_model(model, skin=skin) + calc.calculate(atoms) + energy_before = calc.results["energy"] + + # Displace and recalculate — should use position-only update + displaced = atoms.copy() + pos = displaced.get_positions() + pos[0, 0] += 0.1 + displaced.set_positions(pos) + + calc.calculate(displaced) + energy_after = calc.results["energy"] + + # Energy should change (atoms moved) + assert energy_before != energy_after + + # Forces should have correct shape + assert calc.results["forces"].shape == (len(atoms), 3) + + +def test_calculator_skin_full_rebuild_on_large_displacement(): + """Large displacement triggers full neighbor list rebuild.""" + model = Lorem(cutoff=5.0, num_features=8, num_spherical_features=2, num_radial=4) + atoms = bulk("Ar") * [2, 2, 2] + + calc = Calculator.from_model(model, skin=0.4) + calc.calculate(atoms) + + # Large displacement exceeding skin/2 + displaced = atoms.copy() + pos = displaced.get_positions() + pos[0, 0] += 0.3 # > 0.5 * 0.4 = 0.2 Å + displaced.set_positions(pos) + + assert calc._nl_cache.needs_update(displaced) is True + calc.calculate(displaced) + assert "energy" in calc.results + + +def test_calculator_cell_change_within_skin(): + """Small cell change (NPT-like) reuses cached neighbor list. + + cell₀ cell₀ * 1.001 + ┌──────────┐ ┌───────────┐ + │ · · · · │ 0.1% scale │ · · · · │ + │ · · · · │ ─────────> │ · · · · │ same neighbor + │ · · · · │ │ · · · · │ list topology + └──────────┘ └───────────┘ + """ + model = Lorem(cutoff=5.0, num_features=8, num_spherical_features=2, num_radial=4) + atoms = bulk("Ar") * [2, 2, 2] + + calc = Calculator.from_model(model, skin=0.5) + calc.calculate(atoms) + + # Small isotropic cell scaling (~0.1%) + scaled = atoms.copy() + scaled.set_cell(atoms.get_cell() * 1.001, scale_atoms=True) + + # Should NOT trigger full rebuild + assert calc._nl_cache.needs_update(scaled) is False + + calc.calculate(scaled) + + # Compare with fresh calculation on scaled structure + calc_fresh = Calculator.from_model(model, skin=0.5) + calc_fresh.calculate(scaled) + + np.testing.assert_allclose( + calc.results["energy"], + calc_fresh.results["energy"], + atol=1e-5, + ) + np.testing.assert_allclose( + calc.results["forces"], + calc_fresh.results["forces"], + atol=1e-5, + ) + + +def test_calculator_cell_change_stress_correct(): + """Stress is correct after geometry-only update (no full rebuild). + + Stress = ∂E/∂ε depends on sr.cell through: + σ = Σ_i R_i ⊗ ∂E/∂R_i + Σ_A cell_A ⊗ ∂E/∂cell_A + + Both terms use current sr.positions and sr.cell, so stress + is correct as long as those are updated. + """ + model = Lorem(cutoff=5.0, num_features=8, num_spherical_features=2, num_radial=4) + atoms = bulk("Ar") * [2, 2, 2] + + calc = Calculator.from_model(model, skin=0.5, stress=True) + calc.calculate(atoms) + + scaled = atoms.copy() + scaled.set_cell(atoms.get_cell() * 1.002, scale_atoms=True) + + assert calc._nl_cache.needs_update(scaled) is False + calc.calculate(scaled) + + calc_fresh = Calculator.from_model(model, skin=0.5, stress=True) + calc_fresh.calculate(scaled) + + np.testing.assert_allclose( + calc.results["stress"], + calc_fresh.results["stress"], + atol=1e-5, + ) + + +def test_calculator_large_cell_change_triggers_rebuild(): + """Large cell change exceeds combined Verlet criterion.""" + model = Lorem(cutoff=5.0, num_features=8, num_spherical_features=2, num_radial=4) + atoms = bulk("Ar") * [2, 2, 2] + + calc = Calculator.from_model(model, skin=0.5) + calc.calculate(atoms) + + scaled = atoms.copy() + scaled.set_cell(atoms.get_cell() * 1.1, scale_atoms=True) + + assert calc._nl_cache.needs_update(scaled) is True + calc.calculate(scaled) + assert "energy" in calc.results + + +def test_calculator_combined_position_and_cell_change(): + """Both position and cell change within skin — correct results.""" + model = Lorem(cutoff=5.0, num_features=8, num_spherical_features=2, num_radial=4) + atoms = bulk("Ar") * [2, 2, 2] + + calc = Calculator.from_model(model, skin=0.5) + calc.calculate(atoms) + + modified = atoms.copy() + modified.set_cell(atoms.get_cell() * 1.001, scale_atoms=True) + pos = modified.get_positions() + pos[0, 0] += 0.05 + modified.set_positions(pos) + + assert calc._nl_cache.needs_update(modified) is False + + calc.calculate(modified) + + calc_fresh = Calculator.from_model(model, skin=0.5) + calc_fresh.calculate(modified) + + np.testing.assert_allclose( + calc.results["energy"], + calc_fresh.results["energy"], + atol=1e-5, + ) + np.testing.assert_allclose( + calc.results["forces"], + calc_fresh.results["forces"], + atol=1e-5, + ) + + def test_lorem_calculator_import(): from lorem import LOREMCalculator diff --git a/tests/test_neighborlist.py b/tests/test_neighborlist.py new file mode 100644 index 0000000..4279274 --- /dev/null +++ b/tests/test_neighborlist.py @@ -0,0 +1,278 @@ +import pytest +from ase.build import bulk + +from lorem.neighborlist import NeighborListCache + + +@pytest.fixture +def ar_bulk(): + return bulk("Ar") * [2, 2, 2] + + +@pytest.fixture +def cache(): + return NeighborListCache(skin=0.5) + + +def test_needs_update_first_call(cache, ar_bulk): + """First call always needs update (no reference stored).""" + assert cache.needs_update(ar_bulk) is True + + +def test_needs_update_no_change(cache, ar_bulk): + """No change after saving reference — no update needed.""" + cache.save_reference(ar_bulk) + assert cache.needs_update(ar_bulk) is False + + +def test_needs_update_small_displacement(cache, ar_bulk): + """Displacement below 0.5*skin — no update needed.""" + cache.save_reference(ar_bulk) + + displaced = ar_bulk.copy() + # Move one atom by 0.1 Å (< 0.5 * 0.5 = 0.25 Å) + pos = displaced.get_positions() + pos[0, 0] += 0.1 + displaced.set_positions(pos) + + assert cache.needs_update(displaced) is False + + +def test_needs_update_large_displacement(cache, ar_bulk): + """Displacement above 0.5*skin — update needed.""" + cache.save_reference(ar_bulk) + + displaced = ar_bulk.copy() + # Move one atom by 0.3 Å (> 0.5 * 0.5 = 0.25 Å) + pos = displaced.get_positions() + pos[0, 0] += 0.3 + displaced.set_positions(pos) + + assert cache.needs_update(displaced) is True + + +def test_needs_update_boundary_displacement(cache, ar_bulk): + """Displacement exactly at 0.5*skin boundary — no update needed (strict >).""" + cache.save_reference(ar_bulk) + + displaced = ar_bulk.copy() + pos = displaced.get_positions() + # Exactly 0.25 Å displacement: (0.25)^2 = 0.0625, threshold = 0.0625 + # strict >, so this should NOT trigger update + pos[0, 0] += 0.25 + displaced.set_positions(pos) + + assert cache.needs_update(displaced) is False + + +def test_needs_update_cell_change_no_shift_info(cache, ar_bulk): + """Without max_cell_shift info, any cell change triggers update.""" + cache.save_reference(ar_bulk) # no max_cell_shift → exact comparison + + modified = ar_bulk.copy() + cell = modified.get_cell() + cell[0, 0] += 0.01 + modified.set_cell(cell) + + assert cache.needs_update(modified) is True + + +def test_needs_update_pbc_change(cache, ar_bulk): + """PBC change triggers update.""" + cache.save_reference(ar_bulk) + + modified = ar_bulk.copy() + modified.set_pbc([True, True, False]) + + assert cache.needs_update(modified) is True + + +def test_needs_update_natoms_change(cache, ar_bulk): + """Number of atoms change triggers update.""" + cache.save_reference(ar_bulk) + + smaller = bulk("Ar") * [2, 2, 1] + assert cache.needs_update(smaller) is True + + +def test_needs_update_numbers_change(cache, ar_bulk): + """Atomic numbers change triggers update.""" + cache.save_reference(ar_bulk) + + modified = ar_bulk.copy() + numbers = modified.get_atomic_numbers() + numbers[0] = 36 # Change Ar (18) to Kr (36) + modified.set_atomic_numbers(numbers) + + assert cache.needs_update(modified) is True + + +def test_save_reference_and_reset(cache, ar_bulk): + """save_reference stores state, reset clears it.""" + cache.save_reference(ar_bulk) + assert cache.needs_update(ar_bulk) is False + + cache.reset() + assert cache.needs_update(ar_bulk) is True + + +def test_cumulative_displacement(ar_bulk): + """Cumulative displacement from reference is what matters.""" + cache = NeighborListCache(skin=0.4) + cache.save_reference(ar_bulk) + + displaced = ar_bulk.copy() + pos = displaced.get_positions() + + # First displacement: 0.15 Å (< 0.5 * 0.4 = 0.2 Å) — ok + pos[0, 0] += 0.15 + displaced.set_positions(pos) + assert cache.needs_update(displaced) is False + + # Cumulative displacement: 0.25 Å (> 0.2 Å) — needs update + pos[0, 0] += 0.10 + displaced.set_positions(pos) + assert cache.needs_update(displaced) is True + + +def test_default_skin(): + """Default skin is 0.25 Å.""" + cache = NeighborListCache() + assert cache.skin == 0.25 + + +# -- Cell-change tests (combined position + cell criterion) -- + + +def test_needs_update_small_cell_change_with_shift(ar_bulk): + """Small cell change with max_cell_shift=1 stays within skin. + + cell₀ (reference) cell = cell₀ * 1.001 + ┌──────────┐ ┌───────────┐ + │ · · · · │ 0.1% scaling │ · · · · │ + │ · · · · │ ───────────> │ · · · · │ + │ · · · · │ Δcell ≈ 0.005 │ · · · · │ + └──────────┘ └───────────┘ + max_shift * Σ|Δcell_A| ≈ 0.016 Å << skin = 0.5 Å + """ + cache = NeighborListCache(skin=0.5) + cache.save_reference(ar_bulk, max_cell_shift=1) + + scaled = ar_bulk.copy() + scaled.set_cell(ar_bulk.get_cell() * 1.001, scale_atoms=True) + + assert cache.needs_update(scaled) is False + + +def test_needs_update_large_cell_change_with_shift(ar_bulk): + """Large cell change with max_cell_shift=1 exceeds skin. + + 10% cell scaling: Δcell_A ≈ 0.5 Å per vector + max_shift * Σ|Δcell_A| ≈ 1 * 1.5 = 1.5 Å >> skin = 0.5 Å + """ + cache = NeighborListCache(skin=0.5) + cache.save_reference(ar_bulk, max_cell_shift=1) + + scaled = ar_bulk.copy() + scaled.set_cell(ar_bulk.get_cell() * 1.1, scale_atoms=True) + + assert cache.needs_update(scaled) is True + + +def test_needs_update_cell_change_zero_shift(ar_bulk): + """With max_cell_shift=0, cell changes have no effect on criterion. + + Non-periodic systems have all cell_shifts = [0,0,0], so cell + deformation doesn't change any pairwise distance. + """ + cache = NeighborListCache(skin=0.5) + cache.save_reference(ar_bulk, max_cell_shift=0) + + modified = ar_bulk.copy() + cell = modified.get_cell() + cell[0, 0] += 1.0 # large change, but shift=0 → no contribution + modified.set_cell(cell) + + assert cache.needs_update(modified) is False + + +def test_needs_update_combined_position_and_cell(ar_bulk): + """Position + cell change, each small, combined within skin. + + d_max = 0.05 Å → 2*d_max = 0.1 Å + Δcell ≈ 0.005 Å → max_shift * Σ|Δcell| ≈ 0.016 Å + total ≈ 0.116 Å < skin = 0.5 Å → no rebuild + """ + cache = NeighborListCache(skin=0.5) + cache.save_reference(ar_bulk, max_cell_shift=1) + + modified = ar_bulk.copy() + modified.set_cell(ar_bulk.get_cell() * 1.001, scale_atoms=True) + pos = modified.get_positions() + pos[0, 0] += 0.05 + modified.set_positions(pos) + + assert cache.needs_update(modified) is False + + +def test_needs_update_combined_exceeds_skin(ar_bulk): + """Position + cell change that individually are small but combined exceed skin. + + skin = 0.2 Å + d_max = 0.06 Å → 2*d_max = 0.12 Å + Δcell ≈ 0.15 Å → max_shift * Σ|Δcell| ≈ 0.09 Å + total ≈ 0.21 Å > skin = 0.2 Å → rebuild! + """ + cache = NeighborListCache(skin=0.2) + cache.save_reference(ar_bulk, max_cell_shift=1) + + modified = ar_bulk.copy() + # Cell change: scale by 1.01 → Δcell ≈ 0.05 per vector, 3 vectors + modified.set_cell(ar_bulk.get_cell() * 1.01, scale_atoms=True) + pos = modified.get_positions() + pos[0, 0] += 0.06 + modified.set_positions(pos) + + assert cache.needs_update(modified) is True + + +def test_needs_update_higher_cell_shift(ar_bulk): + """Higher max_cell_shift amplifies cell deformation contribution. + + For Ar FCC 2×2×2 (cell vectors ≈ 7.4 Å each), a 1% scaling gives: + - Position d_max ≈ 0.09 Å (atom farthest from origin) + - |Δcell_A| ≈ 0.074 Å per vector, Σ ≈ 0.22 Å + + Combined criterion: 2*d_max + max_shift * Σ|Δcell_A| + - shift=1: 0.18 + 0.22 = 0.40 < 0.5 → ok + - shift=3: 0.18 + 0.67 = 0.85 > 0.5 → rebuild + """ + scaled = ar_bulk.copy() + scaled.set_cell(ar_bulk.get_cell() * 1.01, scale_atoms=True) + + # With shift=1: within skin + cache1 = NeighborListCache(skin=0.5) + cache1.save_reference(ar_bulk, max_cell_shift=1) + assert cache1.needs_update(scaled) is False + + # With shift=3: same deformation, amplified → exceeds skin + cache3 = NeighborListCache(skin=0.5) + cache3.save_reference(ar_bulk, max_cell_shift=3) + assert cache3.needs_update(scaled) is True + + +def test_reset_clears_max_cell_shift(ar_bulk): + """Reset clears max_cell_shift, reverting to exact cell comparison.""" + cache = NeighborListCache(skin=0.5) + cache.save_reference(ar_bulk, max_cell_shift=1) + + scaled = ar_bulk.copy() + scaled.set_cell(ar_bulk.get_cell() * 1.001, scale_atoms=True) + + # With shift info: small change tolerated + assert cache.needs_update(scaled) is False + + # After reset + re-save without shift info: exact comparison + cache.reset() + cache.save_reference(ar_bulk) + assert cache.needs_update(scaled) is True