3333from jax ._src .pallas import core as pallas_core
3434from jax ._src .pallas .mosaic_gpu import core as gpu_core
3535from jax ._src .pallas .mosaic_gpu import lowering
36+ from jax ._src .pallas .mosaic_gpu .core import state_types
3637from jax ._src .state import discharge
3738from jax ._src .state import indexing
3839from jax ._src .state import primitives as state_primitives
4445WARPGROUP_SIZE = 128
4546
4647
48+ _Ref = pallas_core .AbstractMemoryRef | state_types .TransformedRef
49+
50+
51+ def _check_ref (
52+ aval : object , name : str , memory_space : gpu_core .GPUMemorySpace
53+ ) -> None :
54+ if not isinstance (aval , state_types .AbstractRef ):
55+ raise TypeError (f"{ name } must be a reference, got { aval } " )
56+ aval_memory_space = getattr (aval , "memory_space" , None ) or gpu_core .GMEM
57+ if aval_memory_space is not memory_space :
58+ raise ValueError (
59+ f"{ name } must be a { memory_space .name .upper ()} reference, got { aval } "
60+ )
61+
62+
4763copy_smem_to_gmem_p = jax_core .Primitive ("copy_smem_to_gmem" )
4864copy_smem_to_gmem_p .multiple_results = True
4965
5066
5167@copy_smem_to_gmem_p .def_effectful_abstract_eval
52- def _copy_smem_to_gmem_abstract_eval (* avals , ** params ):
53- del avals , params # Unused.
68+ def _copy_smem_to_gmem_abstract_eval (src , dst , * args , ** params ):
69+ _check_ref (src , "src" , gpu_core .SMEM )
70+ _check_ref (dst , "dst" , gpu_core .GMEM )
71+ del args , params # Unused.
5472 return (), {state .ReadEffect (0 ), state .WriteEffect (1 )}
5573
5674
@@ -115,9 +133,7 @@ def _extract_smem_copy_params(transforms):
115133
116134
117135def copy_smem_to_gmem (
118- src : pallas_core .AbstractMemoryRef ,
119- dst : pallas_core .AbstractMemoryRef ,
120- predicate : jax .Array | None = None ,
136+ src : _Ref , dst : _Ref , predicate : jax .Array | None = None
121137) -> None :
122138 """Asynchronously copies a SMEM reference to a GMEM reference.
123139
@@ -131,10 +147,6 @@ def copy_smem_to_gmem(
131147 :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem`
132148 :func:`jax.experimental.mosaic.gpu.commit_smem`
133149 """
134- if src .memory_space is not gpu_core .SMEM :
135- raise TypeError (f"src must be a SMEM reference, got { src .memory_space } " )
136- if getattr (dst , "memory_space" , gpu_core .GMEM ) is not gpu_core .GMEM :
137- raise ValueError (f"dst must be a GMEM reference, got { dst .memory_space } " )
138150 src , src_transforms = state_primitives .get_ref_and_transforms (
139151 src , None , "copy_smem_to_gmem" , force_trailing_indexer = False ,
140152 )
@@ -165,8 +177,11 @@ def copy_smem_to_gmem(
165177
166178
167179@copy_gmem_to_smem_p .def_effectful_abstract_eval
168- def _copy_gmem_to_smem_abstract_eval (* avals , ** params ):
169- del avals , params # Unused.
180+ def _copy_gmem_to_smem_abstract_eval (src , dst , barrier , * args , ** params ):
181+ del args , params # Unused.
182+ _check_ref (src , "src" , gpu_core .GMEM )
183+ _check_ref (dst , "dst" , gpu_core .SMEM )
184+ _check_ref (barrier , "barrier" , gpu_core .SMEM )
170185 return (), {state .ReadEffect (0 ), state .WriteEffect (1 )}
171186
172187
@@ -218,21 +233,13 @@ def _copy_gmem_to_smem_lowering(
218233 return ()
219234
220235
221- def copy_gmem_to_smem (
222- src : pallas_core .AbstractMemoryRef ,
223- dst : pallas_core .AbstractMemoryRef ,
224- barrier : pallas_core .AbstractMemoryRef ,
225- ) -> None :
236+ def copy_gmem_to_smem (src : _Ref , dst : _Ref , barrier : _Ref ) -> None :
226237 """Asynchronously copies a GMEM reference to a SMEM reference.
227238
228239 See also:
229240 :func:`jax.experimental.mosaic.gpu.barrier_arrive`
230241 :func:`jax.experimental.mosaic.gpu.barrier_wait`
231242 """
232- if getattr (src , "memory_space" , gpu_core .GMEM ) is not gpu_core .GMEM :
233- raise TypeError (f"src must be a GMEM reference, got { src .memory_space } " )
234- if dst .memory_space is not gpu_core .SMEM :
235- raise ValueError (f"dst must be a SMEM reference, got { dst .memory_space } " )
236243 src , src_transforms = state_primitives .get_ref_and_transforms (
237244 src , None , "copy_gmem_to_smem" , force_trailing_indexer = False ,
238245 )
@@ -292,8 +299,9 @@ def _extract_barrier_indexer(transforms) -> indexing.NDIndexer | None:
292299
293300
294301@barrier_arrive_p .def_effectful_abstract_eval
295- def _barrier_arrive_abstract_eval (* avals , ** params ):
296- del avals , params # Unused.
302+ def _barrier_arrive_abstract_eval (barrier , * args , ** params ):
303+ del args , params # Unused.
304+ _check_ref (barrier , "barrier" , gpu_core .SMEM )
297305 return (), {gpu_core ._memory_effect }
298306
299307
@@ -329,8 +337,9 @@ def barrier_arrive(barrier: pallas_core.AbstractMemoryRef) -> None:
329337
330338
331339@barrier_wait_p .def_effectful_abstract_eval
332- def _barrier_wait_abstract_eval (* avals , ** params ):
333- del avals , params # Unused.
340+ def _barrier_wait_abstract_eval (barrier , * args , ** params ):
341+ _check_ref (barrier , "barrier" , gpu_core .SMEM )
342+ del args , params # Unused.
334343 return (), {gpu_core ._memory_effect }
335344
336345
0 commit comments