@@ -418,8 +418,9 @@ class LoweringResult:
418418 module : ir .Module
419419 grid : tuple [int , ...]
420420 block : tuple [int , ...]
421- out_structs : tuple [jax .ShapeDtypeStruct , ...]
421+ new_out_shapes : tuple [jax .ShapeDtypeStruct , ...] # Does not include gmem scratch!
422422 profiler_context : ProfilerContext | None
423+ gmem_scratch_shapes : tuple [jax .ShapeDtypeStruct , ...]
423424
424425
425426@dataclasses .dataclass (frozen = True )
@@ -588,16 +589,41 @@ def ref_for_aval(aval: jax_core.AbstractValue):
588589 else :
589590 return gpu_core .SMEM (aval .shape , aval .dtype )
590591
592+ sem_placeholder = None
593+ semaphore_ref_avals = []
594+ scratch_avals = []
595+ # Need to unzip semaphores
596+ for v in jaxpr .invars [grid_mapping .slice_scratch_ops ]:
597+ aval = v .aval
598+ if (isinstance (aval , pallas_core .AbstractMemoryRef ) and
599+ jnp .issubdtype (aval .dtype , pallas_core .semaphore_dtype )):
600+ if aval .memory_space != gpu_core .GPUMemorySpace .GMEM :
601+ raise ValueError (
602+ "Only GMEM memory space is supported for semaphores in Mosaic GPU."
603+ )
604+ semaphore_ref_avals .append (aval )
605+ scratch_avals .append (sem_placeholder )
606+ else :
607+ scratch_avals .append (aval )
608+
591609 def pipeline_fn (* refs ):
592- return primitives .run_scoped (
593- functools .partial (scoped_pipeline_fn , * refs ),
610+ sem_refs = []
611+ if semaphore_ref_avals :
612+ refs , sem_refs = util .split_list (refs , [- len (semaphore_ref_avals )])
613+ primitives .run_scoped (
614+ functools .partial (scoped_pipeline_fn , * refs , sem_refs = sem_refs ),
594615 scratch_refs = [
595- ref_for_aval (v . aval )
596- for v in jaxpr . invars [ grid_mapping . slice_scratch_ops ]
616+ ref_for_aval (aval ) if aval is not sem_placeholder else aval
617+ for aval in scratch_avals
597618 ],
598619 )
620+ return () # ``wrap_init`` does not support functions returning None.
599621
600- def scoped_pipeline_fn (* refs , scratch_refs ):
622+ def scoped_pipeline_fn (* refs , sem_refs , scratch_refs ):
623+ sem_refs_it = iter (sem_refs )
624+ scratch_refs = [
625+ next (sem_refs_it ) if r is sem_placeholder else r for r in scratch_refs
626+ ]
601627 def body_fn (* refs ):
602628 grid_env = pallas_core .current_grid_env ()
603629 assert grid_env is not None # Set by ``emit_pipeline``.
@@ -628,17 +654,13 @@ def body_fn(*refs):
628654
629655 with grid_mapping .trace_env ():
630656 new_jaxpr , _ , new_consts , () = pe .trace_to_jaxpr_dynamic (
631- lu .wrap_init (
632- # ``wrap_init`` does not support functions returning None.
633- lambda * args : pipeline_fn (* args ) or (),
634- debug_info = jaxpr .debug_info ,
635- ),
657+ lu .wrap_init (pipeline_fn , debug_info = jaxpr .debug_info ),
636658 [
637659 gpu_core .GMEM (
638660 bm .array_shape_dtype .shape , bm .array_shape_dtype .dtype
639661 ).get_ref_aval ()
640662 for bm in block_mappings
641- ],
663+ ] + semaphore_ref_avals ,
642664 )
643665 assert not new_consts
644666
@@ -655,6 +677,10 @@ def body_fn(*refs):
655677 mesh .cluster if mesh is not None else (),
656678 [bm .array_shape_dtype for bm in in_block_mappings ],
657679 [bm .array_shape_dtype for bm in out_block_mappings ],
680+ [
681+ jax .ShapeDtypeStruct (r .shape , np .dtype (np .int32 ))
682+ for r in semaphore_ref_avals
683+ ],
658684 new_jaxpr ,
659685 compiler_params ,
660686 new_consts ,
@@ -668,6 +694,7 @@ def lower_jaxpr_to_module(
668694 cluster : Sequence [int ],
669695 in_shapes : Sequence [jax .ShapeDtypeStruct ],
670696 out_shapes : Sequence [jax .ShapeDtypeStruct ],
697+ gmem_scratch_shapes : Sequence [jax .ShapeDtypeStruct ],
671698 jaxpr : jax_core .Jaxpr ,
672699 compiler_params : dict [str , Any ],
673700 consts = (),
@@ -754,14 +781,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
754781 # Each range is 2 events, each event is 4 bytes.
755782 prof_spec = mgpu_profiler .ProfilerSpec (prof_space * 2 * 4 )
756783 prof_ctx = ProfilerContext (params ["profile_dir" ], prof_spec )
757- module , out_structs_gmem , _ , launch_ctx , scratch_arr = (
784+ module , new_out_shapes , _ , launch_ctx , scratch_arr = (
758785 mgpu_core ._lower_as_gpu_kernel (
759786 body ,
760787 grid = tuple (map (operator .mul , parallel_grid , cluster )),
761788 cluster = cluster ,
762789 block = block ,
763790 in_shapes = in_shapes ,
764- out_shape = out_shapes ,
791+ out_shape = ( * out_shapes , * gmem_scratch_shapes ) ,
765792 smem_scratch_shape = scratch_buffers ,
766793 module_name = mlir .sanitize_name (debug_info .func_name ),
767794 prof_spec = prof_spec ,
@@ -777,8 +804,11 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
777804
778805 mgpu_core ._initialize_scratch (launch_ctx , scratch_arr )
779806
807+ if gmem_scratch_shapes :
808+ new_out_shapes = new_out_shapes [:- len (gmem_scratch_shapes )]
809+
780810 return LoweringResult (
781- module , parallel_grid , block , out_structs_gmem , prof_ctx
811+ module , parallel_grid , block , new_out_shapes , prof_ctx , tuple ( gmem_scratch_shapes )
782812 )
783813
784814
0 commit comments