Replies: 3 comments 3 replies
-
As far as I know, JAX does not support this use-case out of the box. Disclaimer: I'm the author of the library.
-
Even though your networks are small, running them on large numbers of data points can create the parallelism needed to take advantage of GPUs or TPUs. So I would definitely try running your code on e.g. a GPU Colab (making sure to use
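For instance, here is a minimal sketch of what I mean by batching over data points with `vmap` (the network, shapes, and parameters below are made up for illustration, not the poster's actual code):

```python
# Sketch: even a tiny MLP becomes accelerator-friendly once it is
# vmapped over thousands of data points, because the batched matmuls
# give XLA plenty of parallelism to exploit.
import jax
import jax.numpy as jnp

def mlp(params, x):
    # Toy 2-layer feed-forward network; sizes are illustrative only.
    w1, b1, w2, b2 = params
    h = jnp.tanh(x @ w1 + b1)
    return h @ w2 + b2

key = jax.random.PRNGKey(0)
k1, k2, kx = jax.random.split(key, 3)
params = (
    jax.random.normal(k1, (4, 32)), jnp.zeros(32),
    jax.random.normal(k2, (32, 1)), jnp.zeros(1),
)

xs = jax.random.normal(kx, (2500, 4))              # many data points
batched = jax.jit(jax.vmap(mlp, in_axes=(None, 0)))  # batch over axis 0 of x
out = batched(params, xs)
print(out.shape)  # (2500, 1)
```

The same code runs unchanged on CPU, GPU, or TPU; the only difference is which backend JAX finds at import time.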
-
I have posted a similar question, but I think the difference in my case is that the networks are of different sizes. Would love to know if you have any tips for this: #7236
-
Dear All,
sorry for the beginner question. I would like to ask for some guidance on CPU parallelization in JAX.
Background: I am working on the so-called planner problem from macroeconomics. It is an infinite-horizon, discrete-time stochastic control problem, and I am using neural networks to approximate the optimal control functions. As the loss function, I use the negative of the average discounted reward across 15k periods and 2.5k initial points. Because of that, each evaluation of the loss function is painfully slow; even mini-batch gradients (with 150 points) take a hell of a lot of time, hence I want to parallelize this.
My networks are relatively small (simple feed-forward networks with 32 neurons per hidden layer), so I assumed a GPU/TPU wouldn't help. While my loss function includes long for loops (accumulating discounted reward over a long horizon), I need to evaluate this operation (and its gradient) on a large number of points, so I hoped that I could use pmap, or ideally vmap/pmap, to distribute this task across multiple CPUs on my university's cluster.
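To make the structure concrete, here is a stripped-down sketch of the kind of loss I mean, with the long for-loop written as a `jax.lax.scan` (the reward and transition functions below are placeholders, not my actual model):

```python
# Sketch: accumulate discounted reward over a long horizon with
# lax.scan instead of a Python for-loop, then vmap over many initial
# points. reward() and step() are dummy stand-ins for the real model.
import jax
import jax.numpy as jnp

def reward(state):
    # Placeholder per-period reward (the real one comes from the model).
    return -jnp.sum(state ** 2)

def step(state):
    # Placeholder deterministic transition; the real one would apply
    # the policy network and the stochastic dynamics.
    return 0.99 * state

def discounted_return(x0, horizon=15_000, beta=0.96):
    def body(carry, _):
        state, disc, total = carry
        total = total + disc * reward(state)
        return (step(state), disc * beta, total), None

    init = (x0, jnp.array(1.0), jnp.array(0.0))
    (_, _, total), _ = jax.lax.scan(body, init, None, length=horizon)
    return total

# Average (negative) discounted return over a mini-batch of initial points.
x0s = jnp.linspace(-1.0, 1.0, 150)[:, None] * jnp.ones((1, 4))
loss = -jnp.mean(jax.vmap(discounted_return)(x0s))
```

The whole thing is jit-compatible and differentiable, so `jax.grad` of this loss is what each worker would evaluate.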
So, my question is: should I expect JAX to work out of the box and register those CPUs right after installation, or would it require more work? Is there a tutorial for this?
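For reference, the kind of thing I am imagining is the sketch below. I am not sure this is the right approach; as far as I understand, JAX exposes the host CPU as a single device by default, and the `XLA_FLAGS` setting splits it into several virtual devices for `pmap`:

```python
# Sketch: expose multiple virtual CPU devices to JAX. The flag must be
# set *before* jax is imported, or it has no effect.
import os
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp

print(jax.devices())  # should now list 8 CPU devices

# Shard a batch across the virtual devices: the leading axis of the
# input must equal the device count.
xs = jnp.arange(8.0 * 4).reshape(8, 4)
per_device_sum = jax.pmap(lambda x: jnp.sum(x ** 2))(xs)
print(per_device_sum.shape)  # (8,)
```

(I realize that within one process, jitted code may already use multiple CPU cores through XLA's own threading, so I am unsure whether this buys anything, and whether it extends to multiple cluster nodes.)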
Any guidance would be very welcome!
Best,
Honza