Possibility of metapackage for selecting CUDA build variant? #334

@matthewfeickert

If a user wants the CUDA build variant of jaxlib, then with Pixi (on a machine with an NVIDIA GPU) it is generally enough to run

$ pixi init jax-gpu-example && cd jax-gpu-example
$ pixi workspace system-requirements add cuda 12
$ pixi add jax
$ pixi list jax
Package  Version  Build                      Size     Kind   Source
jax      0.7.2    pyhd8ed1ab_0               1.8 MiB  conda  https://conda.anaconda.org/conda-forge/
jaxlib   0.7.2    cuda129_py314h80c6225_202  166 MiB  conda  https://conda.anaconda.org/conda-forge/

However, there can be situations in which the resolution ends up with the CPU build variant, and there is no way (that I know of) to check this before install without explicitly requesting a CUDA build variant of jaxlib

$ pixi add jax 'jaxlib[build=cuda*]'
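
For checking after resolution which variant was actually selected, one option is to inspect the Build column that `pixi list` prints. The following is a minimal sketch, not an official Pixi feature: the helper names `build_string` and `is_cuda_build` are hypothetical, and it assumes the column layout shown in the listing above and that CUDA builds carry a `cuda` prefix in the build string (the same assumption the `build=cuda*` selector makes).

```python
from typing import Optional


def build_string(pixi_list_output: str, package: str) -> Optional[str]:
    """Return the Build column for `package` from `pixi list` text, or None if absent."""
    for line in pixi_list_output.splitlines():
        fields = line.split()
        # Data rows start with the package name; the third column is the build string.
        if len(fields) >= 3 and fields[0] == package:
            return fields[2]
    return None


def is_cuda_build(pixi_list_output: str, package: str = "jaxlib") -> bool:
    """True if the package is listed and its build string starts with 'cuda' (assumed convention)."""
    build = build_string(pixi_list_output, package)
    return build is not None and build.startswith("cuda")


# Sample text copied from the `pixi list jax` output in this issue.
SAMPLE = """\
Package  Version  Build                      Size     Kind   Source
jax      0.7.2    pyhd8ed1ab_0               1.8 MiB  conda  https://conda.anaconda.org/conda-forge/
jaxlib   0.7.2    cuda129_py314h80c6225_202  166 MiB  conda  https://conda.anaconda.org/conda-forge/
"""

if __name__ == "__main__":
    print(build_string(SAMPLE, "jaxlib"))
    print(is_cuda_build(SAMPLE))
```

In practice the text would come from running `pixi list jaxlib` in the workspace (for example by capturing its output with `subprocess.run`). This only detects a CPU fallback after the solve, so it does not replace the explicit `jaxlib[build=cuda*]` request.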

Would it be possible to create a metapackage that selects the CUDA build variant of jaxlib, but at the jax level? Something similar to the pytorch-gpu metapackage?

I appreciate that maintaining jaxlib is difficult enough as is, so asking for additional features right now might not be in scope, but I am curious whether this has been considered before and whether there are other reasons not to do this.

Labels: question (Further information is requested)