Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 63 additions & 5 deletions src/lorem/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from ase.calculators.calculator import (
BaseCalculator,
PropertyNotImplementedError,
compare_atoms,
)

from lorem.neighborlist import NeighborListCache


class Calculator(BaseCalculator):
name = "lorem"
Expand All @@ -35,12 +36,16 @@ def __init__(
stress=False,
add_offset=True,
double_precision=False,
skin=0.25,
):
self.params = params
self.cutoff = cutoff
self.skin = skin
self.add_offset = add_offset
self.double_precision = double_precision

self._nl_cache = NeighborListCache(skin=skin)

if not stress:
self.implemented_properties = ["born_effective_charges", "energy", "forces"]

Expand Down Expand Up @@ -91,20 +96,73 @@ def from_checkpoint(
return cls(model.predict, species_to_weight, params, model.cutoff, **kwargs)

def update(self, atoms):
changes = compare_atoms(self.atoms, atoms)

if len(changes) > 0:
if self._nl_cache.needs_update(atoms):
# Structural change or combined displacement beyond skin
self.results = {}
self.atoms = atoms.copy()
self.setup(atoms)
elif self.atoms is None or not self._geometry_unchanged(atoms):
# Positions and/or cell changed but within skin budget
self.results = {}
self.atoms = atoms.copy()
self._update_geometry(atoms)

def _geometry_unchanged(self, atoms):
return np.array_equal(
atoms.get_positions(), self.atoms.get_positions()
) and np.array_equal(atoms.get_cell()[:], self.atoms.get_cell()[:])

def setup(self, atoms):
from lorem.batching import to_batch, to_sample

sample = to_sample(atoms, self.cutoff, energy=False, forces=False, stress=False)
nl_cutoff = self.cutoff + self.skin

# Derive Ewald parameters from physical cutoff so the long-range
# decomposition is unchanged when using the extended cutoff.
lr_wavelength = self.cutoff / 8.0
smearing = lr_wavelength * 2.0

sample = to_sample(
atoms,
nl_cutoff,
lr_wavelength=lr_wavelength,
smearing=smearing,
energy=False,
forces=False,
stress=False,
)
batch = to_batch([sample], [])
self.batch = jax.tree.map(lambda x: jnp.array(x), batch)

max_cell_shift = int(np.abs(np.array(self.batch.sr.cell_shifts)).max())
self._nl_cache.save_reference(atoms, max_cell_shift=max_cell_shift)

def _update_geometry(self, atoms):
"""Update positions and cell in cached batch without rebuilding
the neighbor list. The model recomputes R_ij from the updated
sr.positions and sr.cell, and the Ewald calculator recomputes
k-vectors from sr.cell (pbc.k_grid stores only integer frequency
indices). So forces, energy, and stress remain correct."""
sr = self.batch.sr
n_atoms = len(atoms)

positions = np.zeros(np.array(sr.positions).shape, dtype=np.float32)
positions[:n_atoms] = atoms.get_positions()

cell = np.array(sr.cell)
new_cell = atoms.get_cell()[:].astype(np.float32)
if atoms.get_pbc().sum() == 2:
from jaxpme.batched_mixed.batching import shrink_2d_cell

new_cell = shrink_2d_cell(new_cell, atoms.get_pbc(), positions[:n_atoms])
cell[0] = new_cell

new_sr = sr._replace(
positions=jnp.array(positions),
cell=jnp.array(cell),
)
self.batch = self.batch._replace(sr=new_sr)

def calculate(
self,
atoms=None,
Expand Down
7 changes: 5 additions & 2 deletions src/lorem/ipi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@


class LOREM_driver(ASEDriver):
def __init__(self, template, model_path, *args, **kwargs):
def __init__(self, template, model_path, *args, skin=0.25, **kwargs):
self.model_path = model_path
self.skin = skin
super().__init__(template, *args, **kwargs)
self.capabilities.append("born_effective_charges")

def check_parameters(self):
super().check_parameters()
has_stress = "stress" in self.capabilities
self.ase_calculator = Calculator.from_checkpoint(self.model_path, stress=has_stress)
self.ase_calculator = Calculator.from_checkpoint(
self.model_path, stress=has_stress, skin=self.skin
)

def compute_structure(self, cell, pos):
pot_ipi, force_ipi, vir_ipi, extras = super().compute_structure(cell, pos)
Expand Down
111 changes: 111 additions & 0 deletions src/lorem/neighborlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Verlet-style neighbor list cache.

Builds neighbor lists with cutoff + skin and reuses them as long as the
combined position displacement and cell deformation stays within the skin
budget. This avoids expensive neighbor searches on every MD step while
guaranteeing that all pairs within the physical cutoff are present in the
cached list.

For a pair (i, j) with periodic image shift S, the change in pairwise
distance from the reference is bounded by:

|dR_ij| <= |dR_i| + |dR_j| + |S . d_cell|
<= 2 * d_max + max_shift * sum(|d_cell_A|)

The neighbor list remains valid as long as this is < skin.
"""

import numpy as np


class NeighborListCache:
"""Cache for neighbor lists with skin-based recomputation.

The Verlet criterion ensures that a neighbor list built with
cutoff + skin contains all pairs within cutoff even after atomic
displacements and cell deformations, as long as the combined
change stays within the skin budget.

Parameters
----------
skin : float
Skin distance in Angstrom. Default 0.25.
"""

def __init__(self, skin=0.25):
self.skin = skin
self._reference_positions = None
self._reference_cell = None
self._reference_pbc = None
self._reference_numbers = None
self._max_cell_shift = None

def needs_update(self, atoms):
"""Check if neighbor list needs recomputation.

Returns True on first call, on any structural change (pbc,
natoms, atomic numbers), or when the combined position
displacement + cell deformation exceeds the skin budget.

When max_cell_shift is not set (direct use without calculator),
falls back to exact cell comparison for backward compatibility.
"""
if self._reference_positions is None:
return True

if len(atoms) != len(self._reference_positions):
return True
if (atoms.get_atomic_numbers() != self._reference_numbers).any():
return True
if (atoms.get_pbc() != self._reference_pbc).any():
return True

# Cell handling depends on whether we have cell shift info
if self._max_cell_shift is None:
# No shift info — exact cell comparison (conservative)
if (atoms.get_cell()[:] != self._reference_cell).any():
return True

# Position displacement: max over atoms of |dR|
displacements = atoms.get_positions() - self._reference_positions
max_disp = np.sqrt((displacements**2).sum(axis=1).max())

# Cell deformation contribution
if self._max_cell_shift is not None and self._max_cell_shift > 0:
# |S . d_cell| <= max_shift * sum_A(|d_cell_A|)
# where |d_cell_A| is the norm of the change in cell vector A
cell_change = atoms.get_cell()[:] - self._reference_cell
cell_vector_norms = np.linalg.norm(cell_change, axis=1)
max_cell_contrib = self._max_cell_shift * cell_vector_norms.sum()
else:
max_cell_contrib = 0.0

# Combined criterion:
# max |dR_ij| <= 2*d_max + max_cell_contrib < skin
return bool(2 * max_disp + max_cell_contrib > self.skin)

def save_reference(self, atoms, max_cell_shift=None):
"""Store reference state after neighbor list rebuild.

Parameters
----------
atoms : ase.Atoms
Reference atomic configuration.
max_cell_shift : int or None
Maximum absolute value of any cell shift component in the
neighbor list. Enables the combined position + cell Verlet
criterion. When None, falls back to exact cell comparison.
"""
self._reference_positions = atoms.get_positions().copy()
self._reference_cell = np.array(atoms.get_cell()[:]).copy()
self._reference_pbc = atoms.get_pbc().copy()
self._reference_numbers = atoms.get_atomic_numbers().copy()
self._max_cell_shift = max_cell_shift

def reset(self):
"""Clear the cache."""
self._reference_positions = None
self._reference_cell = None
self._reference_pbc = None
self._reference_numbers = None
self._max_cell_shift = None
Loading