Precise control of Device with default_device
#11935
Unanswered
joeryjoery
asked this question in
Q&A
Replies: 1 comment
I was trying to get jax.default_device working because I want more control over where computation is placed, but it seems that jax has no attribute for this context manager. It looks like it should exist, though; see #9118 and https://jax.readthedocs.io/en/latest/_autosummary/jax.default_device.html
I tried to hack around this by setting jax.jit(..., device=device) on my desired function. However, because I call it iteratively, I get loads of back-and-forth copying between the global/default device and the computation device (causing my computer to freeze up and hog RAM in a way that would make even Chrome jealous). I'm aware of global device settings such as CUDA_VISIBLE_DEVICES, but a TensorFlow-session-like context manager would be optimal.

Does anyone know if I missed something? I'm using JAX version 0.3.5.
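
For reference, here is a minimal sketch of the kind of usage I have in mind. It assumes a JAX release that actually exposes jax.default_device (I have not confirmed which version that is), and it falls back to explicitly placing the input with jax.device_put when the attribute is missing. The names target, step, and run are hypothetical and only for illustration.

```python
import contextlib

import jax
import jax.numpy as jnp

# Hypothetical choice of device: the last one JAX can see.
target = jax.devices()[-1]


@jax.jit
def step(x):
    # Stand-in for the real iterative computation.
    return x * 2.0 + 1.0


def run(n_steps, device):
    # Use the context manager if this JAX version has it (assumption);
    # otherwise use a no-op context and rely on device_put below.
    if hasattr(jax, "default_device"):
        ctx = jax.default_device(device)
    else:
        ctx = contextlib.nullcontext()

    with ctx:
        # Place the initial value on the chosen device once, so the
        # loop carries an array that already lives there.
        x = jax.device_put(jnp.zeros((4,)), device)
        for _ in range(n_steps):
            x = step(x)
    return x


result = run(3, target)
```

The idea is that the array is put on the chosen device once up front, so each iteration should operate on data that is already there rather than copying from the default device on every call.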