Description
Comment:
If a user wants the CUDA build variant of jaxlib, then when working with Pixi on a machine with an NVIDIA GPU it is (generally) fine to just do:
$ pixi init jax-gpu-example && cd jax-gpu-example
$ pixi workspace system-requirements add cuda 12
$ pixi add jax
$ pixi list jax
Package  Version  Build                      Size     Kind   Source
jax      0.7.2    pyhd8ed1ab_0               1.8 MiB  conda  https://conda.anaconda.org/conda-forge/
jaxlib   0.7.2    cuda129_py314h80c6225_202  166 MiB  conda  https://conda.anaconda.org/conda-forge/

However, there can be situations in which the solve ends up with the CPU build variant, and there's no way (that I know of) to check this before install without explicitly requesting a CUDA build variant of jaxlib:
$ pixi add jax 'jaxlib[build=cuda*]'
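
For checking after the fact which variant a solve picked, the build string shown by `pixi list jax` can be matched against the same `cuda*` pattern. The following is only a minimal sketch: `is_cuda_build` is a hypothetical helper name, and it assumes that CUDA builds of jaxlib have a build string starting with `cuda`, as in the listing above.

```shell
# Hypothetical helper: succeeds when a conda build string names a CUDA variant.
# Assumption: CUDA builds of jaxlib start with "cuda", as in the listing above.
is_cuda_build() {
  case "$1" in
    cuda*) return 0 ;;
    *)     return 1 ;;
  esac
}

# Build string copied by hand from the `pixi list jax` output above.
if is_cuda_build "cuda129_py314h80c6225_202"; then
  echo "jaxlib: cuda build"
else
  echo "jaxlib: not a cuda build"
fi
```

This only inspects an environment that has already been resolved, so it does not remove the need for the explicit `jaxlib[build=cuda*]` request.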
Would it be possible to create a metapackage that selects the CUDA build variant of jaxlib, but at the jax level? Something similar to the pytorch-gpu metapackage?
I appreciate that maintaining jaxlib is difficult enough as is, so asking for additional features right now might not be in scope, but I was curious whether this had been considered before and whether there are other reasons not to do this.