Want to know what I'm doing wrong in implementation here #444
dhruvsreenivas asked this question in Q&A (unanswered)
Hi everyone, I hope you're doing well! I'm working on a research project with the DeepMind JAX ecosystem (Haiku, Optax), but for some reason, when I train over a dataset the training loss doesn't go down, as shown in this screenshot.

I'm trying to do something pretty simple: train Random Network Distillation (https://arxiv.org/abs/1810.12894, https://github.com/deepmind/acme/tree/master/acme/agents/jax/rnd) on an offline dataset of D4RL MuJoCo data. I tried a few sanity checks, including training on a single random data point for a number of iterations. That loss also doesn't go down: it basically stays at 0.005 for 1000 straight epochs (shown in the screenshots below):
Here are some snippets:
RND neural network + trainer code:
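Roughly, the network and trainer setup looks something like the following. This is a simplified sketch, not my exact code: names like `RNDTrainState`, `init_state`, and the MLP sizes are placeholders I'm using here for illustration.

```python
from typing import NamedTuple

import haiku as hk
import jax
import jax.numpy as jnp
import optax


def _embedding_fn(obs: jnp.ndarray) -> jnp.ndarray:
    # The fixed random target network and the trainable predictor share this
    # architecture; only the predictor's parameters are ever updated.
    return hk.nets.MLP([256, 256, 64])(obs)


embed = hk.without_apply_rng(hk.transform(_embedding_fn))


class RNDTrainState(NamedTuple):
    predictor_params: hk.Params
    target_params: hk.Params  # randomly initialized, never trained
    opt_state: optax.OptState


def init_state(rng: jnp.ndarray, dummy_obs: jnp.ndarray,
               optimizer: optax.GradientTransformation) -> RNDTrainState:
    pred_key, target_key = jax.random.split(rng)
    predictor_params = embed.init(pred_key, dummy_obs)
    target_params = embed.init(target_key, dummy_obs)
    return RNDTrainState(predictor_params, target_params,
                         optimizer.init(predictor_params))


def rnd_loss(predictor_params: hk.Params, target_params: hk.Params,
             obs: jnp.ndarray) -> jnp.ndarray:
    # Mean squared prediction error between the predictor and the frozen target.
    pred = embed.apply(predictor_params, obs)
    target = jax.lax.stop_gradient(embed.apply(target_params, obs))
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))
```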
Training loop code:
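And the training loop is along these lines (again a simplified sketch continuing from the code above; `dataset` and `num_epochs` stand in for my actual workspace objects):

```python
optimizer = optax.adam(learning_rate=1e-3)
dummy_obs = jnp.zeros((1, 17))  # e.g. HalfCheetah observation size
state = init_state(jax.random.PRNGKey(0), dummy_obs, optimizer)


@jax.jit
def update_step(state: RNDTrainState, obs: jnp.ndarray):
    # Differentiate only with respect to the predictor parameters.
    loss, grads = jax.value_and_grad(rnd_loss)(
        state.predictor_params, state.target_params, obs)
    updates, new_opt_state = optimizer.update(
        grads, state.opt_state, state.predictor_params)
    new_params = optax.apply_updates(state.predictor_params, updates)
    # The updated params and opt_state are returned and carried into the next step.
    return state._replace(predictor_params=new_params,
                          opt_state=new_opt_state), loss


num_epochs = 1000
for epoch in range(num_epochs):
    for obs_batch in dataset:  # batches of D4RL MuJoCo observations
        state, loss = update_step(state, obs_batch)
```

For the single-data-point sanity check I mentioned above, I just feed the same batch through `update_step` for 1000 epochs instead of iterating over the dataset.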
where `self` refers to a workspace with an experiment config `cfg` in which I train and save everything of interest. As shown, I use the `optax.adam` optimizer with learning rate `1e-3`. I think this is standard (maybe a bit large, but I've swept through a few learning rates, both larger and smaller, and get the same results).

I'm wondering where I'm going wrong in this training approach--I think I have it correct, but there's clearly something I'm missing. Any help would be greatly appreciated! If you have any additional questions, I'll be happy to send updates either here or over a video chat. Also, let me know if the Optax repo is the right place for this message--I don't think it's an issue yet (more on me than on the package), so I'm putting it in the Discussions tab.
Regarding package versions, I'm using Haiku 0.0.7, Optax 0.1.3, and JAX 0.3.16 on CUDA for these experiments. I love the framework, by the way!