Even after JAX has been imported, you can still force it to run on the CPU with either of the following methods:

  1. Set a Default Device for JAX Operations
  • You can explicitly set the default device to the CPU by updating JAX’s configuration:
import jax
import jax.numpy as jnp

# Make the first CPU device the default for all subsequent operations
jax.config.update("jax_default_device", jax.devices("cpu")[0])

# Create an array and perform a computation on the CPU
x = jnp.ones((3, 3))
y = jnp.linalg.inv(x + jnp.eye(3))
print(y.device)
# TFRT_CPU_0
  2. Manually Specify CPU for Computation
  • If you want to run specific computations on the CPU while keeping the GPU available for other tasks, you can explicitly place arrays and computations on the CPU:
import jax
import jax.numpy as jnp

cpu_device = jax.devices("cpu")[0]  # same CPU device handle used in method 1
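
Per-computation placement on the CPU can be sketched roughly as follows. This is a minimal illustration, not necessarily what the answer originally showed: it assumes a recent JAX release in which `jax.device_put`, the `jax.default_device` context manager, and `Array.devices()` are available, and the variable names (`cpu`, `x`, `y`, `z`) are chosen only for this example.

```python
import jax
import jax.numpy as jnp

# Handle to the first CPU device (always present, even on GPU machines).
cpu = jax.devices("cpu")[0]

# Option A: commit an array to the CPU. Computations that consume a
# committed array run on that array's device.
x = jax.device_put(jnp.ones((3, 3)), cpu)
y = jnp.linalg.inv(x + jnp.eye(3))

# Option B: temporarily change the default device for a block of code.
# Outside the `with` block, the previous default device applies again.
with jax.default_device(cpu):
    z = jnp.arange(4.0) * 2

print(y.devices())  # set of devices holding y
print(z.devices())  # set of devices holding z
```

Option A is convenient when only a few inputs need to live on the CPU; Option B is probably easier when a whole block of code should run there without touching each array.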

Answer selected by cool-RR