@@ -154,23 +154,34 @@ jax = "*"
154154[tool .pixi .feature .backends .target .win-64 .dependencies ]
155155# jax = "*" # unavailable
156156
157- # Backends that require a GPU host and a CUDA driver
158- # Note: JAX and PyTorch automatically install CUDA variants
159- # thanks to the `system-requirements` below.
157+ # Backends that require a GPU host and a CUDA driver.
158+ # Note that JAX and PyTorch automatically prefer CUDA variants
159+ # thanks to the `system-requirements` below, *if available*.
160+ # We request them explicitly below to ensure that we don't
161+ # quietly revert to CPU-only in the future, e.g. when CUDA 13
162+ # is released and CUDA 12 builds are dropped upstream.
160163[tool .pixi .feature .cuda-backends ]
161164system-requirements = { cuda = " 12" }
162165
163166[tool .pixi .feature .cuda-backends .target .linux-64 .dependencies ]
164167cupy = " *"
168+ jaxlib = { version = " *" , build = " cuda12*" }
169+ pytorch = { version = " *" , build = " cuda12*" }
165170
166171[tool .pixi .feature .cuda-backends .target .osx-64 .dependencies ]
167172# cupy = "*" # unavailable
173+ # jaxlib = { version = "*", build = "cuda12*" } # unavailable
174+ # pytorch = { version = "*", build = "cuda12*" } # unavailable
168175
169176[tool .pixi .feature .cuda-backends .target .osx-arm64 .dependencies ]
170177# cupy = "*" # unavailable
178+ # jaxlib = { version = "*", build = "cuda12*" } # unavailable
179+ # pytorch = { version = "*", build = "cuda12*" } # unavailable
171180
172181[tool .pixi .feature .cuda-backends .target .win-64 .dependencies ]
173182cupy = " *"
183+ # jaxlib = { version = "*", build = "cuda12*" } # unavailable
184+ pytorch = { version = " *" , build = " cuda12*" }
174185
175186[tool .pixi .environments ]
176187default = { features = [" py313" ], solve-group = " py313" }
0 commit comments