File tree Expand file tree Collapse file tree 2 files changed +9
-4
lines changed
pkgs/development/python-modules Expand file tree Collapse file tree 2 files changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -95,7 +95,7 @@ buildPythonPackage {
9595 wheelUnpackHook
9696 ] ;
9797
98- # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a xla custom call .
98+ # jax-cuda12-plugin looks for ptxas at runtime, e.g. with a triton kernel .
9999 # Linking into $out is the least bad solution. See
100100 # * https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621
101101 # * https://github.com/NixOS/nixpkgs/pull/288829#discussion_r1493852211
Original file line number Diff line number Diff line change @@ -11,13 +11,18 @@ pkgs.writers.writePython3Bin "jax-test-cuda"
1111 }
1212 ''
1313 import jax
14+ import jax.numpy as jnp
1415 from jax import random
16+ from jax.experimental import sparse
1517
16- assert jax.devices()[0].platform == "gpu"
18+ assert jax.devices()[0].platform == "gpu" # libcuda.so
1719
18- rng = random.PRNGKey (0)
20+ rng = random.key (0) # libcudart.so, libcudnn.so
1921 x = random.normal(rng, (100, 100))
20- x @ x
22+ x @ x # libcublas.so
23+ jnp.fft.fft(x) # libcufft.so
24+ jnp.linalg.inv(x) # libcusolver.so
25+ sparse.CSR.fromdense(x) @ x # libcusparse.so
2126
2227 print("success!")
2328 ''
You can’t perform that action at this time.
0 commit comments