Hi,
Rendering during training is unusually slow on Colab's v2-8 TPU VMs, likely because rendering runs on the CPU. In principle, it should be possible to render images directly as JAX arrays on the accelerator (TPU/GPU), which could speed things up significantly.
Could the current render function be adapted so that it can be compiled with `jax.jit`?
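To illustrate the kind of thing I mean, here is a minimal sketch of a jit-compatible renderer. It assumes a hypothetical scene of 2D circles (centers in pixel units, plus radii); `render_circles` is a made-up name for illustration and not part of the library's API. Because it uses only `jax.numpy` ops on fixed-shape arrays, it compiles with `jax.jit`, runs on whatever default backend is available (TPU/GPU/CPU), and can be batched with `jax.vmap`.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical renderer: image size is static so output shapes are fixed under jit.
@partial(jax.jit, static_argnames=("height", "width"))
def render_circles(centers, radii, height=64, width=64):
    # Pixel-center coordinates, each of shape (height, width).
    ys, xs = jnp.meshgrid(
        jnp.arange(height) + 0.5, jnp.arange(width) + 0.5, indexing="ij"
    )
    # Offsets from every circle center to every pixel: (num_circles, height, width).
    dx = xs[None] - centers[:, 0][:, None, None]
    dy = ys[None] - centers[:, 1][:, None, None]
    inside = dx**2 + dy**2 <= radii[:, None, None] ** 2
    # A pixel is lit (1.0) if it lies inside any circle, else 0.0.
    return jnp.any(inside, axis=0).astype(jnp.float32)


centers = jnp.array([[32.0, 32.0]])
radii = jnp.array([10.0])
img = render_circles(centers, radii)

# Batch rendering over many environments at once, entirely on the accelerator.
batch = jax.vmap(render_circles)(
    jnp.stack([centers, centers + 5.0]), jnp.stack([radii, radii])
)
print(img.shape, batch.shape)
```

Keeping the image size static and avoiding Python-side control flow on array values is what lets this stay inside a single compiled program, so frames never need to round-trip through the host.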
Thanks!