Skip to content

Commit 59f337a

Browse files
authored
python3Packages.jax: add missing cuda libraries (#375186)
2 parents 3db6580 + f9c16aa commit 59f337a

File tree

4 files changed

+44
-12
lines changed

4 files changed

+44
-12
lines changed

pkgs/development/python-modules/jax-cuda12-pjrt/default.nix

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,16 @@ let
1818
cudaLibPath = lib.makeLibraryPath (
1919
with cudaPackages;
2020
[
21+
(lib.getLib libcublas) # libcublas.so
22+
(lib.getLib cuda_cupti) # libcupti.so
2123
(lib.getLib cuda_cudart) # libcudart.so
2224
(lib.getLib cudnn) # libcudnn.so
23-
(lib.getLib libcublas) # libcublas.so
24-
addDriverRunpath.driverLink # libcuda.so
25+
(lib.getLib libcufft) # libcufft.so
26+
(lib.getLib libcusolver) # libcusolver.so
27+
(lib.getLib libcusparse) # libcusparse.so
28+
(lib.getLib nccl) # libnccl.so
29+
(lib.getLib libnvjitlink) # libnvJitLink.so
30+
(lib.getLib addDriverRunpath.driverLink) # libcuda.so
2531
]
2632
);
2733

@@ -83,6 +89,8 @@ buildPythonPackage {
8389

8490
pythonImportsCheck = [ "jax_plugins" ];
8591

92+
inherit cudaLibPath;
93+
8694
meta = {
8795
description = "JAX XLA PJRT Plugin for NVIDIA GPUs";
8896
homepage = "https://github.com/jax-ml/jax/tree/main/jax_plugins/cuda";

pkgs/development/python-modules/jax-cuda12-plugin/default.nix

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
jax-cuda12-pjrt,
1313
}:
1414
let
15-
inherit (cudaPackages) cudaVersion;
1615
inherit (jaxlib) version;
16+
inherit (cudaPackages) cudaVersion;
17+
inherit (jax-cuda12-pjrt) cudaLibPath;
1718

1819
getSrcFromPypi =
1920
{
@@ -94,12 +95,34 @@ buildPythonPackage {
9495
wheelUnpackHook
9596
];
9697

98+
# jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel.
99+
# Linking into $out is the least bad solution. See
100+
# * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
101+
# * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
102+
# * https://github.com/NixOS/nixpkgs/pull/375186
103+
# for more info.
104+
postInstall = ''
105+
mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
106+
ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
107+
ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
108+
'';
109+
110+
# jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen
111+
# and these implicit dependencies are not recognized by ldd or
112+
# autoPatchelfHook. That means we need to sneak them into rpath. This step
113+
# must be done after autoPatchelfHook and the automatic stripping of
114+
# artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
115+
# patchPhase.
116+
preInstallCheck = ''
117+
patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
118+
'';
119+
97120
dependencies = [ jax-cuda12-pjrt ];
98121

99122
pythonImportsCheck = [ "jax_cuda12_plugin" ];
100123

101-
# no tests
102-
doCheck = false;
124+
# FIXME: there are no tests, but we need to run preInstallCheck above
125+
doCheck = true;
103126

104127
meta = {
105128
description = "JAX Plugin for CUDA12";

pkgs/development/python-modules/jax/default.nix

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,6 @@ buildPythonPackage rec {
198198

199199
meta = {
200200
description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
201-
longDescription = ''
202-
This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
203-
e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
204-
'';
205201
homepage = "https://github.com/google/jax";
206202
license = lib.licenses.asl20;
207203
maintainers = with lib.maintainers; [ samuela ];

pkgs/development/python-modules/jax/test-cuda.nix

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,18 @@ pkgs.writers.writePython3Bin "jax-test-cuda"
1111
}
1212
''
1313
import jax
14+
import jax.numpy as jnp
1415
from jax import random
16+
from jax.experimental import sparse
1517
16-
assert jax.devices()[0].platform == "gpu"
18+
assert jax.devices()[0].platform == "gpu" # libcuda.so
1719
18-
rng = random.PRNGKey(0)
20+
rng = random.key(0) # libcudart.so, libcudnn.so
1921
x = random.normal(rng, (100, 100))
20-
x @ x
22+
x @ x # libcublas.so
23+
jnp.fft.fft(x) # libcufft.so
24+
jnp.linalg.inv(x) # libcusolver.so
25+
sparse.CSR.fromdense(x) @ x # libcusparse.so
2126
2227
print("success!")
2328
''

0 commit comments

Comments
 (0)