You can set the JAX platform to 'cpu' using any one of these approaches:

  1. set the JAX_PLATFORM_NAME=cpu environment variable;
  2. put jax.config.update('jax_platform_name', 'cpu') at the top of your main file;
  3. parse command-line flags with absl and pass the argument --jax_platform_name=cpu.
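Approach 1 can also be applied from inside the script itself, as long as the variable is set before jax is imported for the first time. The sketch below assumes that JAX reads the JAX_PLATFORM_NAME default when its config module loads, which is how the versions I know of behave. Recent JAX releases describe this option as deprecated in favour of JAX_PLATFORMS (the jax_platforms option), so check the documentation for the version you have installed.

```python
import os

# Approach 1 from inside Python: the variable must be set *before* the first
# `import jax`, because JAX picks up environment-variable defaults at import time.
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax  # noqa: E402  (deliberately imported after setting the variable)
import jax.numpy as jnp  # noqa: E402

# Every device JAX exposes by default should now be a CPU device.
print(jax.devices())
print(jnp.sin(0.0))
```

Setting the variable in the shell (JAX_PLATFORM_NAME=cpu python your_script.py) has the same effect without touching the code.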

Here's an example:

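```python
# issue6023.py
import jax
import jax.numpy as jnp

# Select the CPU platform before running any computation.
jax.config.update('jax_platform_name', 'cpu')

print(jnp.sin(0.))
```

```
$ python issue6023.py
0.0
```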

The warning appears because, by default, JAX looks for a hardware accelerator (GPU or TPU) and only falls back to the CPU when it can't find one. If you set the platform to CPU explicitly, JAX doesn't go looking for an accelerator, so it has no missing accelerator to warn you about.
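To confirm that the setting took effect, you can ask JAX which platform it selected. This is a minimal sketch using jax.default_backend() and jax.devices(); the helper name running_on_cpu_only is hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp

# Pin the platform before the first computation initializes a backend.
jax.config.update("jax_platform_name", "cpu")


def running_on_cpu_only() -> bool:
    """Hypothetical helper: True if JAX's default devices are all CPUs."""
    return all(device.platform == "cpu" for device in jax.devices())


# jax.default_backend() names the platform computations will run on.
print("default backend:", jax.default_backend())
print("CPU only:", running_on_cpu_only())
print(jnp.sin(0.0))
```

If the platform setting is applied after JAX has already initialized its backends (for example, after an earlier computation), it may not take effect, which is why the update goes at the top of the main file.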

Answer selected by thisiscam
Category
Q&A