Skip to content

Commit 13980f6

Browse files
committed
Move images to the new Snapshot.extras attribute
1 parent e0dc9a0 commit 13980f6

File tree

6 files changed

+14
-16
lines changed

6 files changed

+14
-16
lines changed

pysages/backends/ase.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ def take_snapshot(simulation, forces=None):
112112
origin = (0.0, 0.0, 0.0)
113113
dt = simulation.dt
114114

115-
# ASE doesn't use images explicitely
116-
return Snapshot(positions, vel_mass, forces, ids, None, Box(H, origin), dt)
115+
return Snapshot(positions, vel_mass, forces, ids, Box(H, origin), dt)
117116

118117

119118
def _calculator_defaults(sig, arg, default=[]):

pysages/backends/hoomd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ def _pack_snapshot(self, positions, vel_mass, forces, rtags, images):
151151
from_dlpack(vel_mass),
152152
from_dlpack(forces),
153153
from_dlpack(rtags),
154-
from_dlpack(images),
155154
self.update_box(),
156155
self.dt,
156+
dict(images=from_dlpack(images)), # extras
157157
)
158158

159159
# NOTE: The order of the callbacks arguments do not match that of the `Snapshot` attributes
@@ -178,7 +178,7 @@ def build_snapshot_methods(sampling_method):
178178

179179
def positions(snapshot):
180180
L = np.diag(snapshot.box.H)
181-
return snapshot.positions[:, :3] + L * snapshot.images
181+
return snapshot.positions[:, :3] + L * snapshot.extras["images"]
182182

183183
else:
184184

pysages/backends/lammps.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,15 @@ def _partial_snapshot(self, include_masses: bool = False):
9292
velocities = from_dlpack(dlext.velocities(self.view, self.location))
9393
forces = from_dlpack(dlext.forces(self.view, self.location))
9494
tags_map = from_dlpack(dlext.tags_map(self.view, self.location))
95-
imgs = from_dlpack(dlext.images(self.view, self.location))
95+
images = from_dlpack(dlext.images(self.view, self.location))
9696

9797
masses = None
9898
if include_masses:
9999
masses = from_dlpack(dlext.masses(self.view, self.location))
100100
vel_mass = (velocities, (masses, types))
101+
extras = dict(images=images)
101102

102-
return Snapshot(positions, vel_mass, forces, tags_map, imgs, None, None)
103+
return Snapshot(positions, vel_mass, forces, tags_map, None, None, extras)
103104

104105
def _update_snapshot(self):
105106
s = self._partial_snapshot()
@@ -109,7 +110,7 @@ def _update_snapshot(self):
109110
box = self._update_box()
110111
dt = self.snapshot.dt
111112

112-
return Snapshot(s.positions, vel_mass, s.forces, s.ids[1:], s.images, box, dt)
113+
return Snapshot(s.positions, vel_mass, s.forces, s.ids[1:], box, dt, s.extras)
113114

114115
def restore(self, prev_snapshot):
115116
"""Replaces this sampler's snapshot with `prev_snapshot`."""
@@ -122,7 +123,7 @@ def take_snapshot(self):
122123
dt = get_timestep(self.context)
123124

124125
return Snapshot(
125-
copy(s.positions), copy(s.vel_mass), copy(s.forces), s.ids[1:], copy(s.images), box, dt
126+
copy(s.positions), copy(s.vel_mass), copy(s.forces), s.ids[1:], box, dt, copy(s.extras)
126127
)
127128

128129

@@ -198,7 +199,7 @@ def unpack(image):
198199

199200
def positions(snapshot):
200201
L = np.diag(snapshot.box.H)
201-
return snapshot.positions[:, :3] + L * vmap(unpack)(snapshot.images)
202+
return snapshot.positions[:, :3] + L * vmap(unpack)(snapshot.extras["images"])
202203

203204
else:
204205

pysages/backends/openmm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,7 @@ def _take_snapshot(self):
8686
origin = (0.0, 0.0, 0.0)
8787
dt = context.getIntegrator().getStepSize() / unit.picosecond
8888

89-
# OpenMM doesn't have images
90-
return Snapshot(positions, vel_mass, forces, ids, None, Box(H, origin), dt)
89+
return Snapshot(positions, vel_mass, forces, ids, Box(H, origin), dt)
9190

9291

9392
def is_on_gpu(view: ContextView):

pysages/backends/snapshot.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ class Snapshot(NamedTuple):
4141
vel_mass: Union[JaxArray, Tuple[JaxArray, JaxArray]]
4242
forces: JaxArray
4343
ids: JaxArray
44-
images: Optional[JaxArray]
4544
box: Box
4645
dt: Union[JaxArray, float]
4746
extras: Optional[Dict[str, Any]] = None
@@ -91,9 +90,9 @@ def restore(view, snapshot, prev_snapshot, restore_vm=restore_vm):
9190
# Special handling for velocities and masses
9291
restore_vm(view, snapshot, prev_snapshot)
9392
# Overwrite images if the backend uses them
94-
if snapshot.images is not None:
95-
images = view(snapshot.images)
96-
images[:] = view(prev_snapshot.images)
93+
if hasattr(snapshot.extras, "images"):
94+
images = view(snapshot.extras["images"])
95+
images[:] = view(prev_snapshot.extras["images"])
9796

9897

9998
def build_data_querier(snapshot_methods, flags):

tests/test_snapshots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_copying():
1919
box = Box(H, origin)
2020
dt = 0.1
2121

22-
old = Snapshot(positions, vel_mass, forces, ids, None, box, dt)
22+
old = Snapshot(positions, vel_mass, forces, ids, box, dt)
2323
new = copy(old)
2424

2525
old_ptr = old.positions.unsafe_buffer_pointer()

0 commit comments

Comments
 (0)