Skip to content

Commit 3de447b

Browse files
authored
Add support for optional Snapshot attributes (#365)
Add support for optional Snapshot attributes
2 parents cbfc56d + 9f3d935 commit 3de447b

23 files changed

+125
-82
lines changed

pysages/backends/ase.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ def take_snapshot(simulation, forces=None):
112112
origin = (0.0, 0.0, 0.0)
113113
dt = simulation.dt
114114

115-
# ASE doesn't use images explicitely
116-
return Snapshot(positions, vel_mass, forces, ids, None, Box(H, origin), dt)
115+
return Snapshot(positions, vel_mass, forces, ids, Box(H, origin), dt)
117116

118117

119118
def _calculator_defaults(sig, arg, default=[]):

pysages/backends/hoomd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ def _pack_snapshot(self, positions, vel_mass, forces, rtags, images):
151151
from_dlpack(vel_mass),
152152
from_dlpack(forces),
153153
from_dlpack(rtags),
154-
from_dlpack(images),
155154
self.update_box(),
156155
self.dt,
156+
dict(images=from_dlpack(images)), # extras
157157
)
158158

159159
# NOTE: The order of the callbacks arguments do not match that of the `Snapshot` attributes
@@ -178,7 +178,7 @@ def build_snapshot_methods(sampling_method):
178178

179179
def positions(snapshot):
180180
L = np.diag(snapshot.box.H)
181-
return snapshot.positions[:, :3] + L * snapshot.images
181+
return snapshot.positions[:, :3] + L * snapshot.extras["images"]
182182

183183
else:
184184

pysages/backends/lammps.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,15 @@ def _partial_snapshot(self, include_masses: bool = False):
9292
velocities = from_dlpack(dlext.velocities(self.view, self.location))
9393
forces = from_dlpack(dlext.forces(self.view, self.location))
9494
tags_map = from_dlpack(dlext.tags_map(self.view, self.location))
95-
imgs = from_dlpack(dlext.images(self.view, self.location))
95+
images = from_dlpack(dlext.images(self.view, self.location))
9696

9797
masses = None
9898
if include_masses:
9999
masses = from_dlpack(dlext.masses(self.view, self.location))
100100
vel_mass = (velocities, (masses, types))
101+
extras = dict(images=images)
101102

102-
return Snapshot(positions, vel_mass, forces, tags_map, imgs, None, None)
103+
return Snapshot(positions, vel_mass, forces, tags_map, None, None, extras)
103104

104105
def _update_snapshot(self):
105106
s = self._partial_snapshot()
@@ -109,7 +110,7 @@ def _update_snapshot(self):
109110
box = self._update_box()
110111
dt = self.snapshot.dt
111112

112-
return Snapshot(s.positions, vel_mass, s.forces, s.ids[1:], s.images, box, dt)
113+
return Snapshot(s.positions, vel_mass, s.forces, s.ids[1:], box, dt, s.extras)
113114

114115
def restore(self, prev_snapshot):
115116
"""Replaces this sampler's snapshot with `prev_snapshot`."""
@@ -122,7 +123,7 @@ def take_snapshot(self):
122123
dt = get_timestep(self.context)
123124

124125
return Snapshot(
125-
copy(s.positions), copy(s.vel_mass), copy(s.forces), s.ids[1:], copy(s.images), box, dt
126+
copy(s.positions), copy(s.vel_mass), copy(s.forces), s.ids[1:], box, dt, copy(s.extras)
126127
)
127128

128129

@@ -198,7 +199,7 @@ def unpack(image):
198199

199200
def positions(snapshot):
200201
L = np.diag(snapshot.box.H)
201-
return snapshot.positions[:, :3] + L * vmap(unpack)(snapshot.images)
202+
return snapshot.positions[:, :3] + L * vmap(unpack)(snapshot.extras["images"])
202203

203204
else:
204205

pysages/backends/openmm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,7 @@ def _take_snapshot(self):
8686
origin = (0.0, 0.0, 0.0)
8787
dt = context.getIntegrator().getStepSize() / unit.picosecond
8888

89-
# OpenMM doesn't have images
90-
return Snapshot(positions, vel_mass, forces, ids, None, Box(H, origin), dt)
89+
return Snapshot(positions, vel_mass, forces, ids, Box(H, origin), dt)
9190

9291

9392
def is_on_gpu(view: ContextView):

pysages/backends/snapshot.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,16 @@
44
from jax import jit
55
from jax import numpy as np
66

7-
from pysages.typing import Callable, JaxArray, NamedTuple, Optional, Tuple, Union
7+
from pysages.typing import (
8+
Any,
9+
Callable,
10+
Dict,
11+
JaxArray,
12+
NamedTuple,
13+
Optional,
14+
Tuple,
15+
Union,
16+
)
817
from pysages.utils import copy, dispatch, identity
918

1019
AbstractBox = NamedTuple("AbstractBox", [("H", JaxArray), ("origin", JaxArray)])
@@ -32,9 +41,9 @@ class Snapshot(NamedTuple):
3241
vel_mass: Union[JaxArray, Tuple[JaxArray, JaxArray]]
3342
forces: JaxArray
3443
ids: JaxArray
35-
images: Optional[JaxArray]
3644
box: Box
3745
dt: Union[JaxArray, float]
46+
extras: Optional[Dict[str, Any]] = None
3847

3948
def __repr__(self):
4049
return "PySAGES " + type(self).__name__
@@ -81,9 +90,9 @@ def restore(view, snapshot, prev_snapshot, restore_vm=restore_vm):
8190
# Special handling for velocities and masses
8291
restore_vm(view, snapshot, prev_snapshot)
8392
# Overwrite images if the backend uses them
84-
if snapshot.images is not None:
85-
images = view(snapshot.images)
86-
images[:] = view(prev_snapshot.images)
93+
if hasattr(snapshot.extras, "images"):
94+
images = view(snapshot.extras["images"])
95+
images[:] = view(prev_snapshot.extras["images"])
8796

8897

8998
def build_data_querier(snapshot_methods, flags):

pysages/methods/abf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class ABFState(NamedTuple):
6969
force: JaxArray
7070
Wp: JaxArray
7171
Wp_: JaxArray
72-
ncalls: int
72+
ncalls: int = 0
7373

7474
def __repr__(self):
7575
return repr("PySAGES " + type(self).__name__)
@@ -187,7 +187,7 @@ def initialize():
187187
force = np.zeros(dims)
188188
Wp = np.zeros(dims)
189189
Wp_ = np.zeros(dims)
190-
return ABFState(xi, bias, hist, Fsum, force, Wp, Wp_, 0)
190+
return ABFState(xi, bias, hist, Fsum, force, Wp, Wp_)
191191

192192
def update(state, data):
193193
"""

pysages/methods/ann.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class ANNState(NamedTuple):
7070
phi: JaxArray
7171
prob: JaxArray
7272
nn: NNData
73-
ncalls: int
73+
ncalls: int = 0
7474

7575
def __repr__(self):
7676
return repr("PySAGES " + type(self).__name__)
@@ -148,7 +148,7 @@ def initialize():
148148
phi = np.zeros(shape)
149149
prob = np.ones(shape)
150150
nn = NNData(ps, np.array(0.0), np.array(1.0))
151-
return ANNState(xi, bias, hist, phi, prob, nn, 0)
151+
return ANNState(xi, bias, hist, phi, prob, nn)
152152

153153
def update(state, data):
154154
ncalls = state.ncalls + 1

pysages/methods/cff.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class CFFState(NamedTuple):
9393
Wp_: JaxArray
9494
nn: NNData
9595
fnn: NNData
96-
ncalls: int
96+
ncalls: int = 0
9797

9898
def __repr__(self):
9999
return repr("PySAGES " + type(self).__name__)
@@ -218,7 +218,7 @@ def initialize():
218218
nn = NNData(ps, np.array(0.0), np.array(1.0))
219219
fnn = NNData(fps, np.zeros(dims), np.array(1.0))
220220

221-
return CFFState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn, fnn, 0)
221+
return CFFState(xi, bias, hist, histp, prob, fe, Fsum, force, Wp, Wp_, nn, fnn)
222222

223223
def update(state, data):
224224
# During the intial stage, when there are not enough collected samples, use ABF

pysages/methods/ffs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
class FFSState(NamedTuple):
2525
xi: JaxArray
2626
bias: Optional[JaxArray]
27-
ncalls: int
27+
ncalls: int = 0
2828

2929
def __repr__(self):
3030
return repr("PySAGES " + type(self).__name__)
@@ -211,7 +211,7 @@ def _ffs(method, snapshot, helpers):
211211
# initialize method
212212
def initialize():
213213
xi = cv(helpers.query(snapshot))
214-
return FFSState(xi, None, 0)
214+
return FFSState(xi, None)
215215

216216
def update(state, data):
217217
xi = cv(data)

pysages/methods/funn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class FUNNState(NamedTuple):
7878
Wp: JaxArray
7979
Wp_: JaxArray
8080
nn: NNData
81-
ncalls: int
81+
ncalls: int = 0
8282

8383
def __repr__(self):
8484
return repr("PySAGES " + type(self).__name__)
@@ -182,7 +182,7 @@ def initialize():
182182
Wp = np.zeros(dims)
183183
Wp_ = np.zeros(dims)
184184
nn = NNData(ps, F, F)
185-
return FUNNState(xi, bias, hist, Fsum, F, Wp, Wp_, nn, 0)
185+
return FUNNState(xi, bias, hist, Fsum, F, Wp, Wp_, nn)
186186

187187
def update(state, data):
188188
# During the intial stage, when there are not enough collected samples, use ABF

0 commit comments

Comments
 (0)