In my application, I have multiple training workers (as Ray actors) that interact with non-JAX environments. The training workers use a JAX-based policy that outputs not only an action but also other statistics during policy evaluation. These statistics, along with the environment outputs, are stacked together and sent to another Ray actor for a non-JAX post-processing procedure.
This process roughly looks like the following:
```python
import ray


def post_process(traj):
    # placeholder: do something in pure Python using numpy
    processed = ...
    return processed


@ray.remote
class Replay:
    def __init__(self):
        self.processed = []

    def add(self, traj):
        self.processed.extend(post_process(traj))


@ray.remote
class DataWorker:
    def run(self):
        # interact with the environment using a JAX policy
        traj = [(step.env.out, step.policy.statistics) for step in steps]
        # step.policy.statistics are DeviceArrays
        return traj


def main():
    replay = Replay.remote()
    data_worker = DataWorker.remote()
    replay.add.remote(data_worker.run.remote())
```
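One thing that may be worth trying is to do the device-to-host copy inside the data worker yourself, and to batch it. That way Ray receives a few large NumPy arrays instead of a long list of small device arrays. For reference, the JAX call that returns host NumPy arrays is `jax.device_get`. `jax.device_put` places data on a (possibly CPU) device but still returns a JAX array. This is only a sketch. It assumes each step's policy statistics are a dict of JAX arrays, and `collect_trajectory` and that layout are hypothetical. `DataWorker.run` could return `collect_trajectory(...)` instead of the raw list.

```python
import jax
import jax.numpy as jnp
import numpy as np


def collect_trajectory(env_outputs, policy_stats):
    """Stack per-step statistics on device, then copy them to host in one call.

    env_outputs: list of per-step NumPy arrays from the (non-JAX) environment.
    policy_stats: list of per-step dicts of JAX arrays (hypothetical layout).
    """
    # Combine N per-step pytrees into one pytree whose leaves have a leading
    # time axis, so each statistic becomes one array instead of N tiny ones.
    stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *policy_stats)
    # jax.device_get copies every leaf to host memory and returns NumPy
    # arrays, which Ray handles like any other NumPy data.
    stats_host = jax.device_get(stacked)
    return {"env": np.stack(env_outputs), "stats": stats_host}
```

Whether this removes the bottleneck depends on where the time is actually spent. Profiling the `post_process` call separately from the transfer and serialization would show which one dominates.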
The problem I am having is that the post-processing is surprisingly slow and is currently the bottleneck of the system. Does anyone have experience handling JAX device arrays across processes, especially with Ray? If so, do you have tips for building a system like this?
Here are some specific questions I would like to ask:
1. Should the data worker manually transfer the data back to the CPU (e.g. with `device_put`) before letting Ray serialize it?
2. If I rewrite the `post_process` function in JAX, does that mean the device arrays never need to be transferred back to the host, potentially saving a lot of time?
3. The `post_process` function processes trajectories of variable length. If I rewrite it in JAX, I expect the JIT-compiled version of `post_process` would be recompiled over 200 times: once per distinct length, even though the computation pattern is the same for every length. Is this the right way to use JAX?