We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b1b1826 commit 3abb55cCopy full SHA for 3abb55c
scripts/test-jax-install.py
@@ -0,0 +1,21 @@
1
+import jax
2
+import jax.numpy as jnp
3
+
4
+devices = jax.devices()
5
+print(f"The available devices are: {devices}")
6
7
+@jax.jit
8
+def matrix_multiply(a, b):
9
+ return jnp.dot(a, b)
10
11
+# Example usage:
12
+key = jax.random.PRNGKey(0)
13
+x = jax.random.normal(key, (1000, 1000))
14
+y = jax.random.normal(key, (1000, 1000))
15
+z = matrix_multiply(x, y)
16
17
+# Now the function is JIT compiled and will likely run on GPU (if available)
18
+print(z)
19
20
21
0 commit comments