Skip to content

Commit 9bd08f0

Browse files
committed
python3Packages.jax-cuda12-plugin: patch like jax-cuda12-pjrt
1 parent 659babe commit 9bd08f0

File tree

1 file changed

+26
-3
lines changed
  • pkgs/development/python-modules/jax-cuda12-plugin

1 file changed

+26
-3
lines changed

pkgs/development/python-modules/jax-cuda12-plugin/default.nix

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
jax-cuda12-pjrt,
1313
}:
1414
let
15-
inherit (cudaPackages) cudaVersion;
1615
inherit (jaxlib) version;
16+
inherit (cudaPackages) cudaVersion;
17+
inherit (jax-cuda12-pjrt) cudaLibPath;
1718

1819
getSrcFromPypi =
1920
{
@@ -94,12 +95,34 @@ buildPythonPackage {
9495
wheelUnpackHook
9596
];
9697

98+
# jax-cuda12-plugin looks for ptxas at runtime, e.g. with a xla custom call.
99+
# Linking into $out is the least bad solution. See
100+
# * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
101+
# * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
102+
# * https://github.com/NixOS/nixpkgs/pull/375186
103+
# for more info.
104+
postInstall = ''
105+
mkdir -p $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
106+
ln -s ${lib.getExe' cudaPackages.cuda_nvcc "ptxas"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
107+
ln -s ${lib.getExe' cudaPackages.cuda_nvcc "nvlink"} $out/${python.sitePackages}/jax_cuda12_plugin/cuda/bin
108+
'';
109+
110+
# jax-cuda12-plugin contains shared libraries that open other shared libraries via dlopen
111+
# and these implicit dependencies are not recognized by ldd or
112+
# autoPatchelfHook. That means we need to sneak them into rpath. This step
113+
# must be done after autoPatchelfHook and the automatic stripping of
114+
# artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the
115+
# patchPhase.
116+
preInstallCheck = ''
117+
patchelf --add-rpath "${cudaLibPath}" $out/${python.sitePackages}/jax_cuda12_plugin/*.so
118+
'';
119+
97120
dependencies = [ jax-cuda12-pjrt ];
98121

99122
pythonImportsCheck = [ "jax_cuda12_plugin" ];
100123

101-
# no tests
102-
doCheck = false;
124+
# FIXME: there are no tests, but we need to run preInstallCheck above
125+
doCheck = true;
103126

104127
meta = {
105128
description = "JAX Plugin for CUDA12";

0 commit comments

Comments
 (0)