|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 |
| -"""Contains Mosaic specific Pallas functions.""" |
16 |
| -from jax._src.pallas.mosaic import ANY |
17 |
| -from jax._src.pallas.mosaic import CMEM |
18 |
| -from jax._src.pallas.mosaic import PrefetchScalarGridSpec |
19 |
| -from jax._src.pallas.mosaic import SMEM |
20 |
| -from jax._src.pallas.mosaic import SemaphoreType |
21 |
| -from jax._src.pallas.mosaic import TPUMemorySpace |
22 |
| -from jax._src.pallas.mosaic import VMEM |
23 |
| -from jax._src.pallas.mosaic import DeviceIdType |
24 |
| -from jax._src.pallas.mosaic import async_copy |
25 |
| -from jax._src.pallas.mosaic import async_remote_copy |
26 |
| -from jax._src.pallas.mosaic import bitcast |
27 |
| -from jax._src.pallas.mosaic import dma_semaphore |
28 |
| -from jax._src.pallas.mosaic import delay |
29 |
| -from jax._src.pallas.mosaic import device_id |
30 |
| -from jax._src.pallas.mosaic import emit_pipeline_with_allocations |
31 |
| -from jax._src.pallas.mosaic import emit_pipeline |
32 |
| -from jax._src.pallas.mosaic import get_pipeline_schedule |
33 |
| -from jax._src.pallas.mosaic import make_pipeline_allocations |
34 |
| -from jax._src.pallas.mosaic import BufferedRef |
35 |
| -from jax._src.pallas.mosaic import encode_kernel_regeneration_metadata |
36 |
| -from jax._src.pallas.mosaic import extract_kernel_regeneration_metadata |
37 |
| -from jax._src.pallas.mosaic import get_barrier_semaphore |
38 |
| -from jax._src.pallas.mosaic import make_async_copy |
39 |
| -from jax._src.pallas.mosaic import make_async_remote_copy |
40 |
| -from jax._src.pallas.mosaic import repeat |
41 |
| -from jax._src.pallas.mosaic import roll |
42 |
| -from jax._src.pallas.mosaic import run_scoped |
43 |
| -from jax._src.pallas.mosaic import semaphore |
44 |
| -from jax._src.pallas.mosaic import semaphore_read |
45 |
| -from jax._src.pallas.mosaic import semaphore_signal |
46 |
| -from jax._src.pallas.mosaic import semaphore_wait |
| 15 | +"""Mosaic-specific Pallas APIs.""" |
| 16 | + |
| 17 | +from jax._src.pallas.mosaic import core |
| 18 | +from jax._src.pallas.mosaic.core import dma_semaphore |
| 19 | +from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec |
| 20 | +from jax._src.pallas.mosaic.core import semaphore |
| 21 | +from jax._src.pallas.mosaic.core import SemaphoreType |
| 22 | +from jax._src.pallas.mosaic.core import TPUMemorySpace |
| 23 | +from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata |
| 24 | +from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata |
| 25 | +from jax._src.pallas.mosaic.lowering import LoweringException |
| 26 | +from jax._src.pallas.mosaic.pipeline import BufferedRef |
| 27 | +from jax._src.pallas.mosaic.pipeline import emit_pipeline |
| 28 | +from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations |
| 29 | +from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule |
| 30 | +from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations |
| 31 | +from jax._src.pallas.mosaic.primitives import async_copy |
| 32 | +from jax._src.pallas.mosaic.primitives import async_remote_copy |
| 33 | +from jax._src.pallas.mosaic.primitives import bitcast |
| 34 | +from jax._src.pallas.mosaic.primitives import delay |
| 35 | +from jax._src.pallas.mosaic.primitives import device_id |
| 36 | +from jax._src.pallas.mosaic.primitives import DeviceIdType |
| 37 | +from jax._src.pallas.mosaic.primitives import get_barrier_semaphore |
| 38 | +from jax._src.pallas.mosaic.primitives import make_async_copy |
| 39 | +from jax._src.pallas.mosaic.primitives import make_async_remote_copy |
| 40 | +from jax._src.pallas.mosaic.primitives import repeat |
| 41 | +from jax._src.pallas.mosaic.primitives import roll |
| 42 | +from jax._src.pallas.mosaic.primitives import run_scoped |
| 43 | +from jax._src.pallas.mosaic.primitives import semaphore_read |
| 44 | +from jax._src.pallas.mosaic.primitives import semaphore_signal |
| 45 | +from jax._src.pallas.mosaic.primitives import semaphore_wait |
| 46 | +from jax._src.pallas.mosaic.primitives import prng_seed |
| 47 | +from jax._src.pallas.mosaic.primitives import prng_random_bits |
47 | 48 | from jax._src.tpu_custom_call import CostEstimate
|
48 |
| -from jax._src.pallas.mosaic import prng_seed |
49 |
| -from jax._src.pallas.mosaic import prng_random_bits |
| 49 | + |
| 50 | +ANY = TPUMemorySpace.ANY |
| 51 | +CMEM = TPUMemorySpace.CMEM |
| 52 | +SMEM = TPUMemorySpace.SMEM |
| 53 | +VMEM = TPUMemorySpace.VMEM |
0 commit comments