Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pysages/backends/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ def take_snapshot(simulation, forces=None):
origin = (0.0, 0.0, 0.0)
dt = simulation.dt

# ASE doesn't use images explicitely
return Snapshot(positions, vel_mass, forces, ids, None, Box(H, origin), dt)
return Snapshot(positions, vel_mass, forces, ids, Box(H, origin), dt)


def _calculator_defaults(sig, arg, default=[]):
Expand Down
4 changes: 2 additions & 2 deletions pysages/backends/hoomd.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ def _pack_snapshot(self, positions, vel_mass, forces, rtags, images):
from_dlpack(vel_mass),
from_dlpack(forces),
from_dlpack(rtags),
from_dlpack(images),
self.update_box(),
self.dt,
dict(images=from_dlpack(images)), # extras
)

# NOTE: The order of the callbacks arguments do not match that of the `Snapshot` attributes
Expand All @@ -178,7 +178,7 @@ def build_snapshot_methods(sampling_method):

def positions(snapshot):
L = np.diag(snapshot.box.H)
return snapshot.positions[:, :3] + L * snapshot.images
return snapshot.positions[:, :3] + L * snapshot.extras["images"]

else:

Expand Down
11 changes: 6 additions & 5 deletions pysages/backends/lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,15 @@ def _partial_snapshot(self, include_masses: bool = False):
velocities = from_dlpack(dlext.velocities(self.view, self.location))
forces = from_dlpack(dlext.forces(self.view, self.location))
tags_map = from_dlpack(dlext.tags_map(self.view, self.location))
imgs = from_dlpack(dlext.images(self.view, self.location))
images = from_dlpack(dlext.images(self.view, self.location))

masses = None
if include_masses:
masses = from_dlpack(dlext.masses(self.view, self.location))
vel_mass = (velocities, (masses, types))
extras = dict(images=images)

return Snapshot(positions, vel_mass, forces, tags_map, imgs, None, None)
return Snapshot(positions, vel_mass, forces, tags_map, None, None, extras)

def _update_snapshot(self):
s = self._partial_snapshot()
Expand All @@ -109,7 +110,7 @@ def _update_snapshot(self):
box = self._update_box()
dt = self.snapshot.dt

return Snapshot(s.positions, vel_mass, s.forces, s.ids[1:], s.images, box, dt)
return Snapshot(s.positions, vel_mass, s.forces, s.ids[1:], box, dt, s.extras)

def restore(self, prev_snapshot):
"""Replaces this sampler's snapshot with `prev_snapshot`."""
Expand All @@ -122,7 +123,7 @@ def take_snapshot(self):
dt = get_timestep(self.context)

return Snapshot(
copy(s.positions), copy(s.vel_mass), copy(s.forces), s.ids[1:], copy(s.images), box, dt
copy(s.positions), copy(s.vel_mass), copy(s.forces), s.ids[1:], box, dt, copy(s.extras)
)


Expand Down Expand Up @@ -198,7 +199,7 @@ def unpack(image):

def positions(snapshot):
L = np.diag(snapshot.box.H)
return snapshot.positions[:, :3] + L * vmap(unpack)(snapshot.images)
return snapshot.positions[:, :3] + L * vmap(unpack)(snapshot.extras["images"])

else:

Expand Down
3 changes: 1 addition & 2 deletions pysages/backends/openmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ def _take_snapshot(self):
origin = (0.0, 0.0, 0.0)
dt = context.getIntegrator().getStepSize() / unit.picosecond

# OpenMM doesn't have images
return Snapshot(positions, vel_mass, forces, ids, None, Box(H, origin), dt)
return Snapshot(positions, vel_mass, forces, ids, Box(H, origin), dt)


def is_on_gpu(view: ContextView):
Expand Down
19 changes: 14 additions & 5 deletions pysages/backends/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
from jax import jit
from jax import numpy as np

from pysages.typing import Callable, JaxArray, NamedTuple, Optional, Tuple, Union
from pysages.typing import (
Any,
Callable,
Dict,
JaxArray,
NamedTuple,
Optional,
Tuple,
Union,
)
from pysages.utils import copy, dispatch, identity

AbstractBox = NamedTuple("AbstractBox", [("H", JaxArray), ("origin", JaxArray)])
Expand Down Expand Up @@ -32,9 +41,9 @@ class Snapshot(NamedTuple):
vel_mass: Union[JaxArray, Tuple[JaxArray, JaxArray]]
forces: JaxArray
ids: JaxArray
images: Optional[JaxArray]
box: Box
dt: Union[JaxArray, float]
extras: Optional[Dict[str, Any]] = None

def __repr__(self):
return "PySAGES " + type(self).__name__
Expand Down Expand Up @@ -81,9 +90,9 @@ def restore(view, snapshot, prev_snapshot, restore_vm=restore_vm):
# Special handling for velocities and masses
restore_vm(view, snapshot, prev_snapshot)
# Overwrite images if the backend uses them
if snapshot.images is not None:
images = view(snapshot.images)
images[:] = view(prev_snapshot.images)
if hasattr(snapshot.extras, "images"):
images = view(snapshot.extras["images"])
images[:] = view(prev_snapshot.extras["images"])


def build_data_querier(snapshot_methods, flags):
Expand Down
4 changes: 2 additions & 2 deletions pysages/methods/abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ABFState(NamedTuple):
force: JaxArray
Wp: JaxArray
Wp_: JaxArray
ncalls: int
ncalls: int = 0

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -187,7 +187,7 @@ def initialize():
force = np.zeros(dims)
Wp = np.zeros(dims)
Wp_ = np.zeros(dims)
return ABFState(xi, bias, hist, Fsum, force, Wp, Wp_, 0)
return ABFState(xi, bias, hist, Fsum, force, Wp, Wp_)

def update(state, data):
"""
Expand Down
4 changes: 2 additions & 2 deletions pysages/methods/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ANNState(NamedTuple):
phi: JaxArray
prob: JaxArray
nn: NNData
ncalls: int
ncalls: int = 0

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -148,7 +148,7 @@ def initialize():
phi = np.zeros(shape)
prob = np.ones(shape)
nn = NNData(ps, np.array(0.0), np.array(1.0))
return ANNState(xi, bias, hist, phi, prob, nn, 0)
return ANNState(xi, bias, hist, phi, prob, nn)

def update(state, data):
ncalls = state.ncalls + 1
Expand Down
4 changes: 2 additions & 2 deletions pysages/methods/cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class CFFState(NamedTuple):
Wp_: JaxArray
nn: NNData
fnn: NNData
ncalls: int
ncalls: int = 0

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -218,7 +218,7 @@ def initialize():
nn = NNData(ps, np.array(0.0), np.array(1.0))
fnn = NNData(fps, np.zeros(dims), np.array(1.0))

return CFFState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn, fnn, 0)
return CFFState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn, fnn)

def update(state, data):
# During the intial stage, when there are not enough collected samples, use ABF
Expand Down
4 changes: 2 additions & 2 deletions pysages/methods/ffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
class FFSState(NamedTuple):
xi: JaxArray
bias: Optional[JaxArray]
ncalls: int
ncalls: int = 0

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -211,7 +211,7 @@ def _ffs(method, snapshot, helpers):
# initialize method
def initialize():
xi = cv(helpers.query(snapshot))
return FFSState(xi, None, 0)
return FFSState(xi, None)

def update(state, data):
xi = cv(data)
Expand Down
4 changes: 2 additions & 2 deletions pysages/methods/funn.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class FUNNState(NamedTuple):
Wp: JaxArray
Wp_: JaxArray
nn: NNData
ncalls: int
ncalls: int = 0

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -182,7 +182,7 @@ def initialize():
Wp = np.zeros(dims)
Wp_ = np.zeros(dims)
nn = NNData(ps, F, F)
return FUNNState(xi, bias, hist, Fsum, F, Wp, Wp_, nn, 0)
return FUNNState(xi, bias, hist, Fsum, F, Wp, Wp_, nn)

def update(state, data):
# During the intial stage, when there are not enough collected samples, use ABF
Expand Down
4 changes: 2 additions & 2 deletions pysages/methods/harmonic_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class HarmonicBiasState(NamedTuple):

xi: JaxArray
bias: JaxArray
ncalls: int
ncalls: int = 0

def __repr__(self):
return repr("PySAGES" + type(self).__name__)
Expand Down Expand Up @@ -119,7 +119,7 @@ def _harmonic_bias(method, snapshot, helpers):
def initialize():
xi, _ = cv(helpers.query(snapshot))
bias = np.zeros((natoms, helpers.dimensionality()))
return HarmonicBiasState(xi, bias, 0)
return HarmonicBiasState(xi, bias)

def update(state, data):
xi, Jxi = cv(data)
Expand Down
4 changes: 2 additions & 2 deletions pysages/methods/metad.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class MetadynamicsState(NamedTuple):
grid_potential: Optional[JaxArray]
grid_gradient: Optional[JaxArray]
idx: int
ncalls: int
ncalls: int = 0

def __repr__(self):
return repr("PySAGES" + type(self).__name__)
Expand Down Expand Up @@ -181,7 +181,7 @@ def initialize():
grid_gradient = np.zeros((*shape, shape.size), dtype=np.float64)

return MetadynamicsState(
xi, bias, heights, centers, sigmas, grid_potential, grid_gradient, 0, 0
xi, bias, heights, centers, sigmas, grid_potential, grid_gradient, 0
)

def update(state, data):
Expand Down
4 changes: 2 additions & 2 deletions pysages/methods/sirens.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class SirensState(NamedTuple): # pylint: disable=R0903
Wp: JaxArray
Wp_: JaxArray
nn: NNData
ncalls: int
ncalls: int = 0

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -241,7 +241,7 @@ def initialize():
else:
histp = prob = fe = None

return SirensState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn, 0)
return SirensState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn)

def update(state, data):
# During the intial stage, when there are not enough collected samples, use ABF
Expand Down
4 changes: 2 additions & 2 deletions pysages/methods/spectral_abf.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class SpectralABFState(NamedTuple):
Wp: JaxArray
Wp_: JaxArray
fun: Fun
ncalls: int
ncalls: int = 0

def __repr__(self):
return repr("PySAGES " + type(self).__name__)
Expand Down Expand Up @@ -177,7 +177,7 @@ def initialize():
Wp = np.zeros(dims)
Wp_ = np.zeros(dims)
fun = fit(Fsum)
return SpectralABFState(xi, bias, hist, Fsum, force, Wp, Wp_, fun, 0)
return SpectralABFState(xi, bias, hist, Fsum, force, Wp, Wp_, fun)

def update(state, data):
# During the intial stage use ABF
Expand Down
4 changes: 2 additions & 2 deletions pysages/methods/unbiased.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class UnbiasedState(NamedTuple):

xi: JaxArray
bias: Optional[JaxArray]
ncalls: int
ncalls: int = 0

def __repr__(self):
return repr("PySAGES" + type(self).__name__)
Expand Down Expand Up @@ -66,7 +66,7 @@ def _unbias(method, snapshot, helpers):

def initialize():
xi = cv(helpers.query(snapshot))
return UnbiasedState(xi, None, 0)
return UnbiasedState(xi, None)

def update(state, data):
xi = cv(data)
Expand Down
Loading
Loading