I'm learning JAX by implementing my own version of Dreamer (see here). To be able to jit the whole training loop (sample batch -> update model, actor, critic -> sample batch -> ...), I implemented the replay buffer with JAX arrays. So, for instance, given an image observation of shape (64, 64, 3), a pre-defined episode length of 1000 steps, and a capacity of (say) 50 episodes, I allocate an array in the following way:

```python
self.data = {
    'observation': jax.device_put(
        jnp.full((capacity, max_episode_length + 1) + (64, 64, 3), jnp.nan, jnp.uint8),
        device,
    ),
    ...
}
```

and insert a new step into an episode with:
```python
self.data['observation'] = self.data['observation'].at[self.idx, position].set(data)
```

Unfortunately, this seems like a very slow operation, and what's weirder to me is that the bigger the capacity, the slower this operation becomes. I've created two gists that demonstrate this behavior:
I created the second gist not because I wanted to compare insertion times between JAX and NumPy (this is already explained nicely in the FAQ), but mostly because I wasn't sure whether the NumPy script would also slow down as the buffer's capacity increases. Any thoughts about this?
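Roughly, the gists do something like the following (a minimal sketch with illustrative shapes and capacities, and a hypothetical `time_insert` helper; not the exact gist code):

```python
import time

import jax.numpy as jnp

def time_insert(capacity, max_episode_length=1000, obs_shape=(64, 64, 3)):
    # Allocate a buffer of the given capacity, then time a single
    # un-jitted .at[].set() insertion into it.
    buffer = jnp.zeros((capacity, max_episode_length + 1) + obs_shape, jnp.uint8)
    buffer.block_until_ready()  # exclude allocation from the measurement
    obs = jnp.ones(obs_shape, jnp.uint8)
    start = time.perf_counter()
    buffer.at[0, 0].set(obs).block_until_ready()
    return time.perf_counter() - start

for capacity in (5, 10, 25, 50):
    print(f'capacity={capacity:3d}  insert took {time_insert(capacity):.4f}s')
```

The measured time grows with the buffer's capacity, even though only a single (64, 64, 3) slot is written.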
Replies: 1 comment 3 replies
Outside JIT, insertion operations like

```python
x = x.at[y].set(z)
```

result in a copy of the full buffer. So it makes sense that (1) it would be slower than NumPy's in-place insertions, and (2) it would be slower for larger buffers than for smaller buffers.

Inside JIT, XLA is able to avoid these copies by essentially compiling them into in-place operations. So I'd suggest JIT-compiling your function to improve the performance.
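For example, a minimal sketch of that suggestion (the `insert` wrapper and the use of `donate_argnums` are illustrative assumptions, not part of the original reply; donation additionally lets XLA reuse the input buffer's memory for the output across the call boundary):

```python
from functools import partial

import jax
import jax.numpy as jnp

# donate_argnums=0 tells XLA it may alias the output to the donated buffer,
# so the compiled update can write in place instead of copying the array.
# Plain jax.jit already avoids copies for updates that stay inside the
# compiled function; donation may be ignored with a warning on some
# backends (e.g. CPU).
@partial(jax.jit, donate_argnums=0)
def insert(buffer, idx, position, obs):
    return buffer.at[idx, position].set(obs)

buffer = jnp.zeros((50, 1001, 64, 64, 3), jnp.uint8)  # illustrative shape
obs = jnp.ones((64, 64, 3), jnp.uint8)
buffer = insert(buffer, 0, 0, obs)  # rebind: the donated input must not be reused
```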