Replies: 2 comments 1 reply
I haven't tried this before, and it's not clear how adding the network to the environment would hook up to training. Does a feature extractor or action post-processor need to be part of the environment? Code pointer in brax: https://github.com/google/brax/blob/ab34392416af8a40045934a0ee02206babd34857/brax/training/networks.py#L368
The benefit of integrating it into the environment is that you don't need to alter or reimplement the RL training scripts. Since these NNs are frozen (not trained), there is no need to expose their parameters beyond the environment.

As usual, once I started building an MRE the issue changed: with small MLP-type networks this approach works easily, without issues! I'll share my not-quite-minimal UV project in case someone wants to try something similar. The issue persists with convolutional layers baked into the env. I'll keep investigating and give an update here if I figure it out!
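The commenter's project isn't shown here, but the "small MLP baked into the env" pattern can be sketched in plain JAX. Everything below is illustrative: `FrozenEncoderEnv`, the parameter names, and the shapes are hypothetical, not the playground or brax API. The key idea is that the frozen parameters are closed over inside the environment, so the trainer never sees them.

```python
import jax
import jax.numpy as jnp

# Hypothetical minimal env: the class and parameter names are
# illustrative, not the playground/brax API.
class FrozenEncoderEnv:
    def __init__(self, params):
        # Close over the frozen parameters; they are never exposed to
        # the trainer, which only interacts with step().
        def encode(x):
            h = jnp.tanh(x @ params["w1"] + params["b1"])
            return h @ params["w2"] + params["b2"]
        # jit once here; inside a jitted training step the call is
        # traced like any other pure function of the inputs.
        self._encode = jax.jit(encode)

    def step(self, raw_obs, action):
        # Post-process the raw observation with the frozen MLP.
        # (A real env would also advance physics and compute rewards.)
        return self._encode(raw_obs)

# Build random frozen parameters for a 4 -> 8 -> 3 MLP.
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = {
    "w1": jax.random.normal(k1, (4, 8)) * 0.1,
    "b1": jnp.zeros(8),
    "w2": jax.random.normal(k2, (8, 3)) * 0.1,
    "b2": jnp.zeros(3),
}

env = FrozenEncoderEnv(params)
obs = env.step(jnp.ones(4), jnp.zeros(1))
print(obs.shape)  # (3,)
```

The same structure applies if the parameters come from a checkpoint (e.g. a flax model's `apply` partially applied with its restored variables) rather than being generated randomly as above.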
Hello!
I'm seeking general advice for using a fixed-parameter, pre-trained neural network as part of a playground environment, e.g. a feature extractor that processes the state for observations, or an action post-processor, etc.
I tried this by loading the parameters into a flax model in the environment's `__init__`. I then also stored the model's `apply()` as an environment member variable, figuring this would let me use it within `step`. Even though I'm using a tiny NN, performing inference with the fixed NN inside `step` leads to silent crashes (no error message, just return code 9 or 1; WSL crashes completely and I have to reboot). Are there any working examples of this approach in playground/brax environments online? If not, I can create a minimal reproduction of my crashes and discuss it here.
System info:
Running on WSL
Training with brax PPO
python 3.10
brax 0.13.0
jax 0.6.2
jax-cuda12-pjrt 0.6.2
jax-cuda12-plugin 0.6.2
jaxlib 0.6.2
jaxopt 0.8.5
msgpack 1.1.1
mujoco 3.3.5
mujoco-mjx 3.3.5
nvidia-cublas-cu12 12.9.1.4
nvidia-cuda-cupti-cu12 12.9.79
nvidia-cuda-nvcc-cu12 12.9.86
nvidia-cuda-nvrtc-cu12 12.9.86
nvidia-cuda-runtime-cu12 12.9.79
nvidia-cudnn-cu12 9.13.0.50
nvidia-cufft-cu12 11.4.1.4
nvidia-cusolver-cu12 11.7.5.82
nvidia-cusparse-cu12 12.5.10.65
nvidia-nccl-cu12 2.28.3
nvidia-nvjitlink-cu12 12.9.86
nvidia-nvshmem-cu12 3.4.5
opt-einsum 3.4.0
optax 0.2.5
playground 0.0.5