Skip to content

Commit 76ccb19

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Added some runtime type checking to copy_* and barrier_* primitives
PiperOrigin-RevId: 710302436
1 parent 7ab61b7 commit 76ccb19

File tree

1 file changed

+33
-24
lines changed

1 file changed

+33
-24
lines changed

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from jax._src.pallas import core as pallas_core
3434
from jax._src.pallas.mosaic_gpu import core as gpu_core
3535
from jax._src.pallas.mosaic_gpu import lowering
36+
from jax._src.pallas.mosaic_gpu.core import state_types
3637
from jax._src.state import discharge
3738
from jax._src.state import indexing
3839
from jax._src.state import primitives as state_primitives
@@ -44,13 +45,30 @@
4445
WARPGROUP_SIZE = 128
4546

4647

48+
_Ref = pallas_core.AbstractMemoryRef | state_types.TransformedRef
49+
50+
51+
def _check_ref(
52+
aval: object, name: str, memory_space: gpu_core.GPUMemorySpace
53+
) -> None:
54+
if not isinstance(aval, state_types.AbstractRef):
55+
raise TypeError(f"{name} must be a reference, got {aval}")
56+
aval_memory_space = getattr(aval, "memory_space", None) or gpu_core.GMEM
57+
if aval_memory_space is not memory_space:
58+
raise ValueError(
59+
f"{name} must be a {memory_space.name.upper()} reference, got {aval}"
60+
)
61+
62+
4763
copy_smem_to_gmem_p = jax_core.Primitive("copy_smem_to_gmem")
4864
copy_smem_to_gmem_p.multiple_results = True
4965

5066

5167
@copy_smem_to_gmem_p.def_effectful_abstract_eval
52-
def _copy_smem_to_gmem_abstract_eval(*avals, **params):
53-
del avals, params # Unused.
68+
def _copy_smem_to_gmem_abstract_eval(src, dst, *args, **params):
69+
_check_ref(src, "src", gpu_core.SMEM)
70+
_check_ref(dst, "dst", gpu_core.GMEM)
71+
del args, params # Unused.
5472
return (), {state.ReadEffect(0), state.WriteEffect(1)}
5573

5674

@@ -115,9 +133,7 @@ def _extract_smem_copy_params(transforms):
115133

116134

117135
def copy_smem_to_gmem(
118-
src: pallas_core.AbstractMemoryRef,
119-
dst: pallas_core.AbstractMemoryRef,
120-
predicate: jax.Array | None = None,
136+
src: _Ref, dst: _Ref, predicate: jax.Array | None = None
121137
) -> None:
122138
"""Asynchronously copies a SMEM reference to a GMEM reference.
123139
@@ -131,10 +147,6 @@ def copy_smem_to_gmem(
131147
:func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
132148
:func:`jax.experimental.mosaic.gpu.commit_smem`
133149
"""
134-
if src.memory_space is not gpu_core.SMEM:
135-
raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
136-
if getattr(dst, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
137-
raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}")
138150
src, src_transforms = state_primitives.get_ref_and_transforms(
139151
src, None, "copy_smem_to_gmem", force_trailing_indexer=False,
140152
)
@@ -165,8 +177,11 @@ def copy_smem_to_gmem(
165177

166178

167179
@copy_gmem_to_smem_p.def_effectful_abstract_eval
168-
def _copy_gmem_to_smem_abstract_eval(*avals, **params):
169-
del avals, params # Unused.
180+
def _copy_gmem_to_smem_abstract_eval(src, dst, barrier, *args, **params):
181+
del args, params # Unused.
182+
_check_ref(src, "src", gpu_core.GMEM)
183+
_check_ref(dst, "dst", gpu_core.SMEM)
184+
_check_ref(barrier, "barrier", gpu_core.SMEM)
170185
return (), {state.ReadEffect(0), state.WriteEffect(1)}
171186

172187

@@ -218,21 +233,13 @@ def _copy_gmem_to_smem_lowering(
218233
return ()
219234

220235

221-
def copy_gmem_to_smem(
222-
src: pallas_core.AbstractMemoryRef,
223-
dst: pallas_core.AbstractMemoryRef,
224-
barrier: pallas_core.AbstractMemoryRef,
225-
) -> None:
236+
def copy_gmem_to_smem(src: _Ref, dst: _Ref, barrier: _Ref) -> None:
226237
"""Asynchronously copies a GMEM reference to a SMEM reference.
227238
228239
See also:
229240
:func:`jax.experimental.mosaic.gpu.barrier_arrive`
230241
:func:`jax.experimental.mosaic.gpu.barrier_wait`
231242
"""
232-
if getattr(src, "memory_space", gpu_core.GMEM) is not gpu_core.GMEM:
233-
raise TypeError(f"src must be a GMEM reference, got {src.memory_space}")
234-
if dst.memory_space is not gpu_core.SMEM:
235-
raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}")
236243
src, src_transforms = state_primitives.get_ref_and_transforms(
237244
src, None, "copy_gmem_to_smem", force_trailing_indexer=False,
238245
)
@@ -292,8 +299,9 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None:
292299

293300

294301
@barrier_arrive_p.def_effectful_abstract_eval
295-
def _barrier_arrive_abstract_eval(*avals, **params):
296-
del avals, params # Unused.
302+
def _barrier_arrive_abstract_eval(barrier, *args, **params):
303+
del args, params # Unused.
304+
_check_ref(barrier, "barrier", gpu_core.SMEM)
297305
return (), {gpu_core._memory_effect}
298306

299307

@@ -329,8 +337,9 @@ def barrier_arrive(barrier: pallas_core.AbstractMemoryRef) -> None:
329337

330338

331339
@barrier_wait_p.def_effectful_abstract_eval
332-
def _barrier_wait_abstract_eval(*avals, **params):
333-
del avals, params # Unused.
340+
def _barrier_wait_abstract_eval(barrier, *args, **params):
341+
_check_ref(barrier, "barrier", gpu_core.SMEM)
342+
del args, params # Unused.
334343
return (), {gpu_core._memory_effect}
335344

336345

0 commit comments

Comments
 (0)