JAX has great support for SIMD/SPMD parallelism through `vmap`, `pmap`, `xmap`, `shmap`, etc., but I'm wondering if the converse (MPMD: running different programs on different devices) is possible. Basically, I have the following situation:
Without `jit` I can use `device_put`/`device_get` to force different calculations to happen on different devices, but I'd like to be able to `jit` the entire thing (this is library code, and the user may want to apply arbitrary other transforms to it). Also, `y1, y2, y3, ...` will in general have different shapes, so some of the usual tricks won't work.

Has anyone dealt with this design pattern before, or does anyone have ideas?
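For reference, the non-`jit` workaround described above might look roughly like the sketch below. It relies on JAX's placement rule that a jitted function executes on the device its committed inputs live on, so `jax.device_put(x, dev)` steers each branch to `dev`. The functions `f1`, `f2`, `f3` and the helper `run_mpmd` are hypothetical stand-ins (the actual computations aren't shown here), and devices are assigned round-robin so the sketch also runs on a single-device machine.

```python
import jax
import jax.numpy as jnp

# Hypothetical branch functions; each returns an output of a different shape.
def f1(x):
    return jnp.sin(x)          # shape (n,)

def f2(x):
    return jnp.outer(x, x)     # shape (n, n)

def f3(x):
    return jnp.sum(x ** 2)     # shape ()

def run_mpmd(x, fns):
    """Run each function in `fns` on its own device and gather results to host.

    Hypothetical helper: placement is driven from Python, so this outer
    function is NOT itself jit-compatible; only the individual branches are jitted.
    """
    devices = jax.devices()
    results = []
    for i, fn in enumerate(fns):
        dev = devices[i % len(devices)]   # round-robin over available devices
        x_dev = jax.device_put(x, dev)    # commit the input to `dev`
        # The jitted branch runs where its committed input lives; dispatch is
        # asynchronous, so branches on different devices can overlap.
        results.append(jax.jit(fn)(x_dev))
    # device_get blocks until every result is ready and copies it to the host.
    return jax.device_get(results)

x = jnp.arange(4.0)
y1, y2, y3 = run_mpmd(x, [f1, f2, f3])
```

Because the orchestration lives in a Python loop, wrapping `run_mpmd` itself in `jax.jit` is not expected to preserve this per-branch placement, which is the limitation the question is asking about.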