|
12 | 12 | jax-cuda12-pjrt, |
13 | 13 | }: |
14 | 14 | let |
15 | | - inherit (cudaPackages) cudaVersion; |
16 | 15 | inherit (jaxlib) version; |
| 16 | + inherit (cudaPackages) cudaVersion; |
| 17 | + inherit (jax-cuda12-pjrt) cudaLibPath; |
17 | 18 |
|
18 | 19 | getSrcFromPypi = |
19 | 20 | { |
@@ -94,12 +95,34 @@ buildPythonPackage { |
94 | 95 | wheelUnpackHook |
95 | 96 | ]; |
96 | 97 |
|
| 98 | + # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a xla custom call. |
| 99 | + # Linking into $out is the least bad solution. See |
| 100 | + # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 |
| 101 | + # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211 |
| 102 | + # * https://github.com/NixOS/nixpkgs/pull/375186 |
| 103 | + # for more info. |
| 104 | + postInstall = '' |
| 105 | + mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin |
| 106 | + ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin |
| 107 | + ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin |
| 108 | + ''; |
| 109 | + |
| 110 | + # jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen |
| 111 | + # and these implicit dependencies are not recognized by ldd or |
| 112 | + # autoPatchelfHook. That means we need to sneak them into rpath. This step |
| 113 | + # must be done after autoPatchelfHook and the automatic stripping of |
| 114 | + # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the |
| 115 | + # patchPhase. |
| 116 | + preInstallCheck = '' |
| 117 | + patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so |
| 118 | + ''; |
| 119 | + |
97 | 120 | dependencies = [ jax-cuda12-pjrt ]; |
98 | 121 |
|
99 | 122 | pythonImportsCheck = [ "jax_cuda12_plugin" ]; |
100 | 123 |
|
101 | | - # no tests |
102 | | - doCheck = false; |
| 124 | + # FIXME: there are no tests, but we need to run preInstallCheck above |
| 125 | + doCheck = true; |
103 | 126 |
|
104 | 127 | meta = { |
105 | 128 | description = "JAX Plugin for CUDA12"; |
|
0 commit comments