How to set default device as CPU #10399
zohimchandani asked this question in Q&A
-
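Option 1: hide the GPUs from CUDA with an environment variable.

```python
import os

# Must be set before JAX initializes its backends (safest: before importing jax).
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import jax

print(jax.numpy.ones(3).devices())  # a CPU device; older JAX releases print it as TFRT_CPU_0
```

Option 2: select the CPU platform through the JAX config.

```python
import jax

# Must run before any computation or array allocation.
jax.config.update('jax_platform_name', 'cpu')

print(jax.numpy.ones(3).devices())  # a CPU device
```

Option 3: compile a single function for the CPU backend.

```python
import jax

def f():
    return jax.numpy.ones(3)

# Only this jitted call targets the CPU; other code keeps using the default device.
print(jax.jit(f, backend='cpu')().devices())  # a CPU device
```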
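A further option, sketched under the assumption that your JAX version provides the `jax.default_device` context manager: instead of hiding the GPU for the whole process, it makes the CPU the default device only inside a `with` block, so the GPU remains usable elsewhere. The variable names `cpu` and `x` are illustrative.

```python
import jax
import jax.numpy as jnp

# jax.devices('cpu') lists the CPU devices even when a GPU is the default backend.
cpu = jax.devices('cpu')[0]

# Arrays created inside this block are placed on the CPU by default.
with jax.default_device(cpu):
    x = jnp.ones(3)

print(x.devices())  # the set containing the CPU device
```

Code outside the `with` block keeps the normal default (the GPU, if one is present), which makes this convenient when only part of a program should run on the CPU.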
-
I would like to benchmark my code on both CPU and GPU, but JAX uses the GPU by default whenever one is available. How can I force it to run on the CPU?
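For the benchmarking goal, hiding the GPU process-wide (as in Options 1 and 2) means running the script twice. A minimal sketch of an alternative: commit the input to a specific device with `jax.device_put`, and the jitted function then runs on that device. The helper `bench`, the matrix size, and the repeat count are hypothetical choices. The first call is excluded from timing because it includes compilation, and `block_until_ready()` is needed because JAX dispatches work asynchronously.

```python
import time

import jax


def matmul(a, b):
    return a @ b


def bench(platform, n=512, repeats=5):
    """Best wall-clock time (seconds) of a jitted matmul on `platform`, or None if unavailable."""
    try:
        device = jax.devices(platform)[0]
    except RuntimeError:
        return None  # no backend for this platform on this machine
    key = jax.random.PRNGKey(0)
    # Committing the input to `device` makes the jitted call run there.
    a = jax.device_put(jax.random.normal(key, (n, n)), device)
    f = jax.jit(matmul)
    f(a, a).block_until_ready()  # compile and warm up outside the timed loop
    best = float('inf')
    for _ in range(repeats):
        start = time.perf_counter()
        f(a, a).block_until_ready()  # wait for the asynchronous result before stopping the clock
        best = min(best, time.perf_counter() - start)
    return best


results = {p: bench(p) for p in ('cpu', 'gpu')}
for platform, t in results.items():
    print(platform, 'unavailable' if t is None else f'{t * 1e3:.2f} ms')
```

If Option 1 or 2 has been applied in the same process, the GPU is hidden and this sketch reports it as unavailable.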