We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8ac2969 commit 05614edCopy full SHA for 05614ed
jax/experimental/pallas/tpu.py
@@ -51,6 +51,7 @@
51
from jax._src.pallas.mosaic.primitives import semaphore_read as semaphore_read
52
from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal
53
from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait
54
+from jax._src.pallas.mosaic.random import sample_block as sample_block
55
from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key
56
57
import types
0 commit comments