Skip to content

Commit 897711f

Browse files
authored
Miscellaneous fixes (#282)
This PR addresses the following small but unrelated issues: - [x] Missed edge cases not handled by #280 (71b5fa4). - [x] Some recent combination of changes on either or both CuPy and Jax makes `cupy.asarray(x)` for `x: JaxArray` return a copy instead of a view (134d085). - [x] Upcoming compatibility changes introduced by SSAGESLabs/hoomd-dlext/pull/23 (71991b5). - [x] Some fonts were intended to be loaded by #281, but currently fail (e6533cd).
1 parent 896f684 commit 897711f

File tree

8 files changed

+72
-62
lines changed

8 files changed

+72
-62
lines changed

docs/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ Make sure you have [GNU make](https://www.gpu.org/software/make/) installed.
1212
You can build the documentation from you local copy of the PySAGES repository as follows:
1313

1414
```shell
15-
pip install -r docs/requirements.txt
15+
cd docs
16+
pip install -r requirements.txt
1617
make html
1718
```
1819

docs/source/_templates/layout.html

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +0,0 @@
1-
{% extends "!layout.html" %} {% block extrahead %}
2-
<link
3-
rel="stylesheet"
4-
href="https://fonts.googleapis.com/css?family=Atkinson Hyperlegible|Montserrat"
5-
/>
6-
{% endblock %}

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
# # so a file named "default.css" will overwrite the builtin "default.css".
101101
html_static_path = ["_static"]
102102
html_css_files = [
103+
"https://fonts.googleapis.com/css?family=Atkinson Hyperlegible|Montserrat",
103104
"https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/fontawesome.min.css",
104105
"https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/solid.min.css",
105106
"https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.2/css/brands.min.css",

pysages/backends/hoomd.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121
from jax import jit
2222
from jax import numpy as np
23-
from jax.dlpack import from_dlpack as asarray
23+
from jax.dlpack import from_dlpack
2424

2525
from pysages.backends.core import SamplingContext
2626
from pysages.backends.snapshot import (
@@ -61,11 +61,15 @@ def remove_half_step_hook(context):
6161
context.integrator.cpp_integrator.removeHalfStepHook()
6262

6363
else:
64+
if hasattr(hoomd.dlext, "__version__"):
65+
SamplerBase = DLExtSampler
6466

65-
class SamplerBase(DLExtSampler, md.HalfStepHook):
66-
def __init__(self, sysview, update, location, mode):
67-
md.HalfStepHook.__init__(self)
68-
DLExtSampler.__init__(self, sysview, update, location, mode)
67+
else:
68+
69+
class SamplerBase(DLExtSampler, md.HalfStepHook):
70+
def __init__(self, sysview, update, location, mode):
71+
md.HalfStepHook.__init__(self)
72+
DLExtSampler.__init__(self, sysview, update, location, mode)
6973

7074
def is_on_gpu(context):
7175
return not isinstance(context.device, hoomd.device.CPU)
@@ -125,11 +129,11 @@ def snapshot_callback(positions, vel_mass, rtags, images, forces, n):
125129

126130
def _pack_snapshot(self, positions, vel_mass, forces, rtags, images):
127131
return Snapshot(
128-
asarray(positions),
129-
asarray(vel_mass),
130-
asarray(forces),
131-
asarray(rtags),
132-
asarray(images),
132+
from_dlpack(positions),
133+
from_dlpack(vel_mass),
134+
from_dlpack(forces),
135+
from_dlpack(rtags),
136+
from_dlpack(images),
133137
self.box,
134138
self.dt,
135139
)
@@ -149,11 +153,11 @@ def default_location():
149153
def take_snapshot(sampling_context, location=default_location()):
150154
context = sampling_context.context
151155
sysview = sampling_context.view
152-
positions = copy(asarray(positions_types(sysview, location, AccessMode.Read)))
153-
vel_mass = copy(asarray(velocities_masses(sysview, location, AccessMode.Read)))
154-
forces = copy(asarray(net_forces(sysview, location, AccessMode.ReadWrite)))
155-
ids = copy(asarray(rtags(sysview, location, AccessMode.Read)))
156-
imgs = copy(asarray(images(sysview, location, AccessMode.Read)))
156+
positions = copy(from_dlpack(positions_types(sysview, location, AccessMode.Read)))
157+
vel_mass = copy(from_dlpack(velocities_masses(sysview, location, AccessMode.Read)))
158+
forces = copy(from_dlpack(net_forces(sysview, location, AccessMode.ReadWrite)))
159+
ids = copy(from_dlpack(rtags(sysview, location, AccessMode.Read)))
160+
imgs = copy(from_dlpack(images(sysview, location, AccessMode.Read)))
157161

158162
check_device_array(positions) # currently, we only support `DeviceArray`s
159163

@@ -200,17 +204,14 @@ def masses(snapshot):
200204

201205

202206
def build_helpers(context, sampling_method):
207+
utils = importlib.import_module(".utils", package="pysages.backends")
208+
203209
# Depending on the device being used we need to use either cupy or numpy
204210
# (or numba) to generate a view of jax's DeviceArrays
205211
if is_on_gpu(context):
206-
cupy = importlib.import_module("cupy")
207-
view = cupy.asarray
208-
209-
def sync_forces():
210-
cupy.cuda.get_current_stream().synchronize()
212+
sync_forces, view = utils.cupy_helpers()
211213

212214
else:
213-
utils = importlib.import_module(".utils", package="pysages.backends")
214215
view = utils.view
215216

216217
def sync_forces():

pysages/backends/lammps.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from jax import jit
1515
from jax import numpy as np
1616
from jax import vmap
17-
from jax.dlpack import from_dlpack as asarray
17+
from jax.dlpack import from_dlpack
1818
from lammps import dlext
1919
from lammps.dlext import ExecutionSpace, FixDLExt, LAMMPSView, has_kokkos_cuda_enabled
2020

@@ -84,16 +84,16 @@ def update(timestep):
8484
self.set_callback(update)
8585

8686
def _partial_snapshot(self, include_masses: bool = False):
87-
positions = asarray(dlext.positions(self.view, self.location))
88-
types = asarray(dlext.types(self.view, self.location))
89-
velocities = asarray(dlext.velocities(self.view, self.location))
90-
forces = asarray(dlext.forces(self.view, self.location))
91-
tags_map = asarray(dlext.tags_map(self.view, self.location))
92-
imgs = asarray(dlext.images(self.view, self.location))
87+
positions = from_dlpack(dlext.positions(self.view, self.location))
88+
types = from_dlpack(dlext.types(self.view, self.location))
89+
velocities = from_dlpack(dlext.velocities(self.view, self.location))
90+
forces = from_dlpack(dlext.forces(self.view, self.location))
91+
tags_map = from_dlpack(dlext.tags_map(self.view, self.location))
92+
imgs = from_dlpack(dlext.images(self.view, self.location))
9393

9494
masses = None
9595
if include_masses:
96-
masses = asarray(dlext.masses(self.view, self.location))
96+
masses = from_dlpack(dlext.masses(self.view, self.location))
9797
vel_mass = (velocities, (masses, types))
9898

9999
return Snapshot(positions, vel_mass, forces, tags_map, imgs, None, None)
@@ -127,19 +127,15 @@ def build_helpers(context, sampling_method, on_gpu, restore_fn):
127127
"""
128128
Builds helper methods used for restoring snapshots and biasing a simulation.
129129
"""
130+
utils = importlib.import_module(".utils", package="pysages.backends")
130131
dim = context.extract_setting("dimension")
131132

132133
# Depending on the device being used we need to use either cupy or numpy
133134
# (or numba) to generate a view of jax's DeviceArrays
134135
if on_gpu:
135-
cupy = importlib.import_module("cupy")
136-
view = cupy.asarray
137-
138-
def sync_forces():
139-
cupy.cuda.get_current_stream().synchronize()
136+
sync_forces, view = utils.cupy_helpers()
140137

141138
else:
142-
utils = importlib.import_module(".utils", package="pysages.backends")
143139
view = utils.view
144140

145141
def sync_forces():

pysages/backends/openmm.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import openmm_dlext as dlext
99
from jax import jit
1010
from jax import numpy as np
11-
from jax.dlpack import from_dlpack as asarray
11+
from jax.dlpack import from_dlpack
1212
from jax.lax import cond
1313
from openmm_dlext import ContextView, DeviceType, Force
1414

@@ -60,15 +60,15 @@ def take_snapshot(sampling_context):
6060
context = sampling_context.context.context # extra indirection for OpenMM
6161
context_view = sampling_context.view
6262

63-
positions = asarray(dlext.positions(context_view))
64-
forces = asarray(dlext.forces(context_view))
65-
ids = asarray(dlext.atom_ids(context_view))
63+
positions = from_dlpack(dlext.positions(context_view))
64+
forces = from_dlpack(dlext.forces(context_view))
65+
ids = from_dlpack(dlext.atom_ids(context_view))
6666

67-
velocities = asarray(dlext.velocities(context_view))
67+
velocities = from_dlpack(dlext.velocities(context_view))
6868
if is_on_gpu(context_view):
6969
vel_mass = velocities
7070
else:
71-
inverse_masses = asarray(dlext.inverse_masses(context_view))
71+
inverse_masses = from_dlpack(dlext.inverse_masses(context_view))
7272
vel_mass = (velocities, inverse_masses.reshape((-1, 1)))
7373

7474
check_device_array(positions) # currently, we only support `DeviceArray`s
@@ -126,29 +126,21 @@ def momenta(snapshot):
126126

127127

128128
def build_helpers(context, sampling_method):
129+
utils = importlib.import_module(".utils", package="pysages.backends")
130+
129131
# Depending on the device being used we need to use either cupy or numpy
130132
# (or numba) to generate a view of jax's DeviceArrays
131133
if is_on_gpu(context):
132-
cupy = importlib.import_module("cupy")
133-
view = cupy.asarray
134-
135134
restore_vm = _restore_vm
136-
137-
def sync_forces():
138-
cupy.cuda.get_current_stream().synchronize()
135+
sync_forces, view = utils.cupy_helpers()
139136

140137
@jit
141138
def adapt(biases):
142139
return np.int64(2**32 * biases.T)
143140

144141
else:
145-
utils = importlib.import_module(".utils", package="pysages.backends")
146-
view = utils.view
147-
148142
adapt = identity
149-
150-
def sync_forces():
151-
pass
143+
view = utils.view
152144

153145
def restore_vm(view, snapshot, prev_snapshot):
154146
# TODO: Check if we can omit modifying the masses
@@ -158,6 +150,9 @@ def restore_vm(view, snapshot, prev_snapshot):
158150
velocities[:] = view(prev_snapshot.vel_mass[0])
159151
masses[:] = view(prev_snapshot.vel_mass[1])
160152

153+
def sync_forces():
154+
pass
155+
161156
def bias(snapshot, state, sync_backend):
162157
"""Adds the computed bias to the forces."""
163158
if state.bias is None:

pysages/backends/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES
33

44
import ctypes
5+
import importlib
56

67
import numba
78
import numpy
@@ -11,6 +12,26 @@
1112
from pysages.utils import dispatch
1213

1314

15+
def cupy_helpers():
16+
"""Returns two methods:
17+
18+
`sync` -- for synchronizing the current CUDA stream
19+
`view` -- to wrap a `JaxArray` as a `cupy.ndarray`
20+
"""
21+
cupy = importlib.import_module("cupy")
22+
dlpack = importlib.import_module("jax.dlpack")
23+
24+
def _sync():
25+
"""Synchronizes the current cupy's CUDA stream."""
26+
cupy.cuda.get_current_stream().synchronize()
27+
28+
def _view(x: JaxArray):
29+
"""Wraps a view of `x: JaxArray` as a `cupy.ndarray`."""
30+
return cupy.from_dlpack(dlpack.to_dlpack(x))
31+
32+
return _sync, _view
33+
34+
1435
@dispatch
1536
def view(array: JaxArray):
1637
"""Return a writable view of a JAX DeviceArray."""

pysages/utils/compat.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def has_method(fn, T, index):
6767
else:
6868
_bt = import_module("beartype.door")
6969
_typing = import_module("plum" if _plum_version_tuple < (2, 2, 1) else "typing")
70+
_util = _typing.type if _plum_version_tuple < (2, 2, 1) else _typing
7071

7172
def dispatch_table(dispatch):
7273
return dispatch.functions
@@ -75,8 +76,8 @@ def has_method(fn, T, index):
7576
types_at_index = set()
7677
for sig in fn.methods:
7778
typ = sig.types[index]
78-
if _typing.get_origin(typ) is _typing.Union:
79-
types_at_index.update(_typing.get_args(typ))
79+
if _util.get_origin(typ) is _typing.Union:
80+
types_at_index.update(_util.get_args(typ))
8081
else:
8182
types_at_index.add(typ)
8283
return T in types_at_index

0 commit comments

Comments
 (0)