2020)
2121from jax import jit
2222from jax import numpy as np
23- from jax .dlpack import from_dlpack as asarray
23+ from jax .dlpack import from_dlpack
2424
2525from pysages .backends .core import SamplingContext
2626from pysages .backends .snapshot import (
@@ -61,11 +61,15 @@ def remove_half_step_hook(context):
6161 context .integrator .cpp_integrator .removeHalfStepHook ()
6262
6363else :
64+ if hasattr (hoomd .dlext , "__version__" ):
65+ SamplerBase = DLExtSampler
6466
65- class SamplerBase (DLExtSampler , md .HalfStepHook ):
66- def __init__ (self , sysview , update , location , mode ):
67- md .HalfStepHook .__init__ (self )
68- DLExtSampler .__init__ (self , sysview , update , location , mode )
67+ else :
68+
69+ class SamplerBase (DLExtSampler , md .HalfStepHook ):
70+ def __init__ (self , sysview , update , location , mode ):
71+ md .HalfStepHook .__init__ (self )
72+ DLExtSampler .__init__ (self , sysview , update , location , mode )
6973
7074 def is_on_gpu (context ):
7175 return not isinstance (context .device , hoomd .device .CPU )
@@ -125,11 +129,11 @@ def snapshot_callback(positions, vel_mass, rtags, images, forces, n):
125129
126130 def _pack_snapshot (self , positions , vel_mass , forces , rtags , images ):
127131 return Snapshot (
128- asarray (positions ),
129- asarray (vel_mass ),
130- asarray (forces ),
131- asarray (rtags ),
132- asarray (images ),
132+ from_dlpack (positions ),
133+ from_dlpack (vel_mass ),
134+ from_dlpack (forces ),
135+ from_dlpack (rtags ),
136+ from_dlpack (images ),
133137 self .box ,
134138 self .dt ,
135139 )
@@ -149,11 +153,11 @@ def default_location():
149153def take_snapshot (sampling_context , location = default_location ()):
150154 context = sampling_context .context
151155 sysview = sampling_context .view
152- positions = copy (asarray (positions_types (sysview , location , AccessMode .Read )))
153- vel_mass = copy (asarray (velocities_masses (sysview , location , AccessMode .Read )))
154- forces = copy (asarray (net_forces (sysview , location , AccessMode .ReadWrite )))
155- ids = copy (asarray (rtags (sysview , location , AccessMode .Read )))
156- imgs = copy (asarray (images (sysview , location , AccessMode .Read )))
156+ positions = copy (from_dlpack (positions_types (sysview , location , AccessMode .Read )))
157+ vel_mass = copy (from_dlpack (velocities_masses (sysview , location , AccessMode .Read )))
158+ forces = copy (from_dlpack (net_forces (sysview , location , AccessMode .ReadWrite )))
159+ ids = copy (from_dlpack (rtags (sysview , location , AccessMode .Read )))
160+ imgs = copy (from_dlpack (images (sysview , location , AccessMode .Read )))
157161
158162 check_device_array (positions ) # currently, we only support `DeviceArray`s
159163
@@ -200,17 +204,14 @@ def masses(snapshot):
200204
201205
202206def build_helpers (context , sampling_method ):
207+ utils = importlib .import_module (".utils" , package = "pysages.backends" )
208+
203209 # Depending on the device being used we need to use either cupy or numpy
204210 # (or numba) to generate a view of jax's DeviceArrays
205211 if is_on_gpu (context ):
206- cupy = importlib .import_module ("cupy" )
207- view = cupy .asarray
208-
209- def sync_forces ():
210- cupy .cuda .get_current_stream ().synchronize ()
212+ sync_forces , view = utils .cupy_helpers ()
211213
212214 else :
213- utils = importlib .import_module (".utils" , package = "pysages.backends" )
214215 view = utils .view
215216
216217 def sync_forces ():
0 commit comments