Can we use vmap without taking data to GPU? #8611
Unanswered
mohamad-amin asked this question in Q&A
Replies: 0 comments
Hey!
I have a function that I'm applying `vmap` to. Inside the `vmap`, I have to use `jax.numpy`, since the function is apparently dealing with `TraceShapedArray`s. At some point I call `scipy.linalg.cho_solve`, which receives `TraceShapedArray`s as inputs and cannot accept them. So I have to use `jax.scipy.linalg.cho_solve` instead, but that automatically moves the `cho_solve` computation to the GPU, even though its input is not a `jax.numpy` array. This is a problem because the data that gets fed into `cho_solve` cannot fit on any GPU. I would therefore like to use `vmap` without running on the GPU. Is there any way to do this?

P.S.: I need GPUs for some `pmap` computations that run before this `vmap`; it is only the `vmap` step that doesn't fit on a GPU, so I can't force JAX to use CPU devices all the time.
Sincerely,
Amin
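
For reference, here is a minimal sketch of the kind of setup described above, together with one possible way of keeping the solve on the CPU while GPUs stay available to the rest of the program. It assumes that committing the inputs to a CPU device with `jax.device_put` is enough for the jitted `vmap` to run there, since JAX runs a computation where its committed inputs live. The function `solve_one`, the toy shapes, and the random data are hypothetical stand-ins rather than the actual code from the question.

```python
import jax
import numpy as np
from jax.scipy.linalg import cho_factor, cho_solve

# Assumption: a CPU backend is available next to any GPUs.
cpu = jax.devices("cpu")[0]


def solve_one(a, b):
    # Hypothetical per-example function: solve a @ x = b via Cholesky.
    c_and_lower = cho_factor(a)
    return cho_solve(c_and_lower, b)


# vmap over the leading (batch) axis of both arguments.
batched_solve = jax.jit(jax.vmap(solve_one))

# Toy data: a batch of 4 symmetric positive-definite 3x3 systems.
rng = np.random.default_rng(0)
m = rng.normal(size=(4, 3, 3))
a_host = m @ m.transpose(0, 2, 1) + 3.0 * np.eye(3)
b_host = rng.normal(size=(4, 3, 1))

# Commit the inputs to the CPU device. A jitted computation whose
# inputs are committed to a device is executed on that device.
a_cpu = jax.device_put(a_host, cpu)
b_cpu = jax.device_put(b_host, cpu)

x = batched_solve(a_cpu, b_cpu)  # expected to run on the CPU backend
```

In this sketch the `pmap` stages would be left untouched and would still use the GPUs; only the arrays passed to `batched_solve` are pinned to the CPU. Whether this is sufficient for the real workload is untested here, and the `jax.default_device` context manager may be another option worth trying.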