Skip to content

Commit f9c16aa

Browse files
committed
python3Packages.jax: add operations to cuda test
1 parent 17b7964 commit f9c16aa

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

pkgs/development/python-modules/jax-cuda12-plugin/default.nix

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ buildPythonPackage {
9595
wheelUnpackHook
9696
];
9797

98-
# jax-cuda12-plugin looks for ptxas at runtime, e.g. with a xla custom call.
98+
# jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel.
9999
# Linking into $out is the least bad solution. See
100100
# * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
101101
# * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211

pkgs/development/python-modules/jax/test-cuda.nix

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,18 @@ pkgs.writers.writePython3Bin "jax-test-cuda"
1111
}
1212
''
1313
import jax
14+
import jax.numpy as jnp
1415
from jax import random
16+
from jax.experimental import sparse
1517
16-
assert jax.devices()[0].platform == "gpu"
18+
assert jax.devices()[0].platform == "gpu" # libcuda.so
1719
18-
rng = random.PRNGKey(0)
20+
rng = random.key(0) # libcudart.so, libcudnn.so
1921
x = random.normal(rng, (100, 100))
20-
x @ x
22+
x @ x # libcublas.so
23+
jnp.fft.fft(x) # libcufft.so
24+
jnp.linalg.inv(x) # libcusolver.so
25+
sparse.CSR.fromdense(x) @ x # libcusparse.so
2126
2227
print("success!")
2328
''

0 commit comments

Comments
 (0)