Pybullet with JAX #9754
-
I have a robot that I want to train with RL in pybullet. After watching this video introducing JAX, I wanted to know how JAX's IR works with pybullet. I have a notion in my mind that might not be accurate. This is how I think JAX works: it optimizes the code that is written in … I have this concern because training natively in C++ would definitely be faster; however, development is time-consuming. And if we develop in Python anyway, then we need to cut down the bottlenecks. Of course, this also depends on how pybullet handles the simulation. With this context, …
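To make the question concrete: JAX's IR is the jaxpr, which JAX builds by tracing a Python function and recording only JAX primitives; `jax.jit` then lowers it and hands it to XLA for compilation. The sketch below (a toy example, not pybullet code) prints the jaxpr for a small `jax.numpy` integration step and then shows that a function calling into a non-JAX library cannot be traced under `jax.jit`. Here `np.asarray` is an assumed stand-in for any external call, such as a pybullet function, and the names `step` and `opaque_step` are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def step(pos, vel, dt):
    # Toy Euler integration step written only with jax.numpy ops,
    # so every operation is recorded as a JAX primitive when traced.
    acc = -9.81 * jnp.ones_like(vel)
    new_vel = vel + dt * acc
    new_pos = pos + dt * new_vel
    return new_pos, new_vel


# make_jaxpr traces `step` and returns its intermediate representation (a jaxpr).
jaxpr = jax.make_jaxpr(step)(jnp.zeros(3), jnp.zeros(3), 0.01)
print(jaxpr)


def opaque_step(pos):
    # np.asarray needs a concrete array; under jit, `pos` is an abstract tracer,
    # so this call cannot be recorded in the jaxpr (stand-in for a non-JAX library).
    return np.asarray(pos) + 1.0


failed = False
try:
    jax.jit(opaque_step)(jnp.zeros(3))
except jax.errors.TracerArrayConversionError:
    failed = True
print("non-JAX call rejected under jit:", failed)
```

The printed jaxpr contains only primitives such as `mul` and `add`; there is no way for code like pybullet's C++ simulation loop to appear in it.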
-
I don't think pybullet is compatible with JAX. JAX has its own set of operations that it knows how to transform and optimize (implemented in `jax.lax`, with NumPy-style wrappers in `jax.numpy`, `jax.scipy`, etc.). There is not really any facility to take existing non-JAX libraries and compile them to XLA. If you can find an implementation of pybullet's operations that someone has built on top of JAX (just as `jax.numpy` implements the equivalent of NumPy's API on top of `jax`), that would be best; otherwise, your only option is to run the JAX code and the pybullet code independently.

You might also check out https://github.com/google/brax, which is JAX-compatible and seems similar to pybullet (note: I haven't used either library, so I may be wrong!)
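A minimal sketch of the "run them independently" option, assuming a simple observation/action loop: the policy is a `jax.jit`-compiled function, while the simulator step runs on the host as ordinary Python. `host_sim_step` is a hypothetical NumPy stand-in for pybullet calls (it does not use the pybullet API), and the linear `policy` and `rollout` helpers are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def policy(params, obs):
    # Linear tanh policy compiled by XLA; only jax.numpy ops are used here.
    return jnp.tanh(params["W"] @ obs + params["b"])


def host_sim_step(state, action, dt=0.01):
    # Hypothetical stand-in for a pybullet step: plain NumPy on the host,
    # never traced by JAX. Returns the new state and a 4-element observation.
    pos, vel = state
    vel = vel + dt * action
    pos = pos + dt * vel
    return (pos, vel), np.concatenate([pos, vel])


def rollout(params, steps=100):
    state = (np.zeros(2), np.zeros(2))
    obs = np.concatenate(state)
    observations = []
    for _ in range(steps):
        # JAX side: evaluate the compiled policy, then copy the action back to the host.
        action = np.asarray(policy(params, jnp.asarray(obs)))
        # Host side: advance the simulator outside of JAX.
        state, obs = host_sim_step(state, action)
        observations.append(obs)
    return np.stack(observations)


key = jax.random.PRNGKey(0)
params = {"W": 0.1 * jax.random.normal(key, (2, 4)), "b": jnp.zeros(2)}
traj = rollout(params, steps=50)
print(traj.shape)  # (50, 4)
```

Because `host_sim_step` runs outside JAX, `jax.grad` cannot differentiate through the rollout, and every step moves data between JAX and the host; that is one reason a JAX-native simulator can be attractive.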