
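From the source code, there is a flag (jax_platforms, which defaults to the lower-cased value of the JAX_PLATFORMS environment variable) that specifies which platforms to use: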

flags.DEFINE_string(
    'jax_platforms',
    os.getenv('JAX_PLATFORMS', '').lower(),
    'Comma-separated list of platform names specifying which platforms jax '
    'should attempt to initialize. The first platform in the list that is '
    'successfully initialized will be used as the default platform. For '
    'example, --jax_platforms=cpu,gpu means that CPU and GPU backends will be '
    'initialized, and the CPU backend will be used unless otherwise specified; '
    '--jax_platforms=cpu means that only the CPU backend will be initialized. '
    'By default, jax will try to initialize all available platforms and will '
    'defa…
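To make the selection rule in that help text concrete, here is a minimal Python sketch. The helpers parse_platforms and select_platforms are hypothetical names, and try_init is a stand-in for whatever actually initializes a backend. This is not JAX's implementation: it only models the behavior as the quoted help text describes it, and the real library may handle a requested platform that fails to initialize differently (for example, by raising an error instead of skipping it). The default-platform rule for an empty flag is cut off in the quoted text, so the sketch does not model it. In practice you would simply set the variable before JAX starts, e.g. JAX_PLATFORMS=cpu python your_script.py (script name hypothetical).

```python
from typing import Callable, Iterable, List, Optional, Tuple


def parse_platforms(value: str) -> List[str]:
    """Split a comma-separated platform string, e.g. "cpu,gpu" -> ["cpu", "gpu"]."""
    # Lower-case the whole value, as the flag's default does for JAX_PLATFORMS.
    return [p.strip() for p in value.lower().split(",") if p.strip()]


def select_platforms(
    value: str,
    try_init: Callable[[str], bool],
    available: Iterable[str],
) -> Tuple[List[str], Optional[str]]:
    """Hypothetical model of the help text's rule.

    Returns (platforms that initialized, default platform or None).
    """
    requested = parse_platforms(value)
    if not requested:
        # Empty flag: try every available platform. The help text's rule for
        # choosing the default in this case is truncated, so return None.
        return [p for p in available if try_init(p)], None
    # Try only the listed platforms, in the order given.
    initialized = [p for p in requested if try_init(p)]
    # The first listed platform that initialized successfully is the default.
    default = initialized[0] if initialized else None
    return initialized, default


if __name__ == "__main__":
    # Simulate a machine where only the CPU backend can be initialized.
    cpu_only = lambda p: p == "cpu"
    print(select_platforms("gpu,cpu", cpu_only, ["cpu"]))  # (['cpu'], 'cpu')
```

With both backends working, "cpu,gpu" initializes both and makes CPU the default, matching the --jax_platforms=cpu,gpu example in the help text.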

Replies: 3 comments 1 reply

@jvlmdr
Answer selected by alonfnt
Category
Q&A
4 participants