
Commit 39fb2a0

apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Add support for allocation and lowering of scratch semaphores
The semaphore arrays are allocated in GMEM and zeroed by XLA before the kernel begins.

PiperOrigin-RevId: 741494241
1 parent 1c1e2e6 commit 39fb2a0
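
For context, here is a minimal usage sketch adapted from the test added in tests/pallas/mosaic_gpu_test.py below; it is not part of the commit message. A kernel requests a scratch semaphore array through plgpu.SemaphoreType, and the array arrives in GMEM already zeroed. Per the test's own comment, lowering of semaphore operations themselves is not supported yet, so this only exercises allocation and lowering of the scratch buffer.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu

def body(i_ref, o_ref, sem_ref):
  # sem_ref is a (4,)-element semaphore array, allocated in GMEM and
  # zero-initialized by XLA before the kernel starts.
  o_ref[...] = i_ref[...]

x = jnp.arange(128, dtype=jnp.float32)
kernel = pl.pallas_call(
    body, out_shape=x, scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))],
)
# Mirroring the test, we only lower the kernel here (a smoke test for
# allocation); semaphore operations themselves are not lowered yet.
print(jax.jit(kernel).lower(x).as_text())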

7 files changed, +117 -22 lines changed

jax/_src/pallas/mosaic_gpu/BUILD

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ pytype_strict_library(
         "//jax:mlir",
         "//jax:mosaic_gpu",
         "//jax/_src/pallas",
-    ],
+    ] + py_deps("numpy"),
 )
 
 pytype_strict_library(

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 19 additions & 0 deletions
@@ -120,6 +120,25 @@ def __call__(
     return GPUMemoryRef(shape, dtype, memory_space=self, transforms=transforms)
 
 
+class SemaphoreType(enum.Enum):
+  REGULAR = "regular"
+  BARRIER = "barrier"
+
+  def __call__(self, shape: tuple[int, ...]):
+    dtype: Any
+    if self == SemaphoreType.BARRIER:
+      dtype = pallas_core.BarrierSemaphore()
+    else:
+      dtype = pallas_core.Semaphore()
+    return pallas_core.MemoryRef(shape, dtype, GPUMemorySpace.GMEM)
+
+  def get_array_aval(self) -> jax_core.ShapedArray:
+    return self(()).get_array_aval()
+
+  def get_ref_aval(self) -> pallas_core.TransformedRef | AbstractMemoryRef:
+    return self(()).get_ref_aval()
+
+
 def kernel(
     body: Callable[..., None],
     out_shape: object,

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 45 additions & 15 deletions
@@ -418,8 +418,9 @@ class LoweringResult:
   module: ir.Module
   grid: tuple[int, ...]
   block: tuple[int, ...]
-  out_structs: tuple[jax.ShapeDtypeStruct, ...]
+  new_out_shapes: tuple[jax.ShapeDtypeStruct, ...]  # Does not include gmem scratch!
   profiler_context: ProfilerContext | None
+  gmem_scratch_shapes: tuple[jax.ShapeDtypeStruct, ...]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -588,16 +589,41 @@ def ref_for_aval(aval: jax_core.AbstractValue):
     else:
       return gpu_core.SMEM(aval.shape, aval.dtype)
 
+  sem_placeholder = None
+  semaphore_ref_avals = []
+  scratch_avals = []
+  # Need to unzip semaphores
+  for v in jaxpr.invars[grid_mapping.slice_scratch_ops]:
+    aval = v.aval
+    if (isinstance(aval, pallas_core.AbstractMemoryRef) and
+        jnp.issubdtype(aval.dtype, pallas_core.semaphore_dtype)):
+      if aval.memory_space != gpu_core.GPUMemorySpace.GMEM:
+        raise ValueError(
+            "Only GMEM memory space is supported for semaphores in Mosaic GPU."
+        )
+      semaphore_ref_avals.append(aval)
+      scratch_avals.append(sem_placeholder)
+    else:
+      scratch_avals.append(aval)
+
   def pipeline_fn(*refs):
-    return primitives.run_scoped(
-        functools.partial(scoped_pipeline_fn, *refs),
+    sem_refs = []
+    if semaphore_ref_avals:
+      refs, sem_refs = util.split_list(refs, [-len(semaphore_ref_avals)])
+    primitives.run_scoped(
+        functools.partial(scoped_pipeline_fn, *refs, sem_refs=sem_refs),
         scratch_refs=[
-            ref_for_aval(v.aval)
-            for v in jaxpr.invars[grid_mapping.slice_scratch_ops]
+            ref_for_aval(aval) if aval is not sem_placeholder else aval
+            for aval in scratch_avals
         ],
     )
+    return ()  # ``wrap_init`` does not support functions returning None.
 
-  def scoped_pipeline_fn(*refs, scratch_refs):
+  def scoped_pipeline_fn(*refs, sem_refs, scratch_refs):
+    sem_refs_it = iter(sem_refs)
+    scratch_refs = [
+        next(sem_refs_it) if r is sem_placeholder else r for r in scratch_refs
+    ]
     def body_fn(*refs):
       grid_env = pallas_core.current_grid_env()
       assert grid_env is not None  # Set by ``emit_pipeline``.
@@ -628,17 +654,13 @@ def body_fn(*refs):
 
   with grid_mapping.trace_env():
     new_jaxpr, _, new_consts, () = pe.trace_to_jaxpr_dynamic(
-        lu.wrap_init(
-            # ``wrap_init`` does not support functions returning None.
-            lambda *args: pipeline_fn(*args) or (),
-            debug_info=jaxpr.debug_info,
-        ),
+        lu.wrap_init(pipeline_fn, debug_info=jaxpr.debug_info),
         [
             gpu_core.GMEM(
                 bm.array_shape_dtype.shape, bm.array_shape_dtype.dtype
             ).get_ref_aval()
             for bm in block_mappings
-        ],
+        ] + semaphore_ref_avals,
     )
   assert not new_consts
 
@@ -655,6 +677,10 @@ def body_fn(*refs):
       mesh.cluster if mesh is not None else (),
       [bm.array_shape_dtype for bm in in_block_mappings],
      [bm.array_shape_dtype for bm in out_block_mappings],
+      [
+          jax.ShapeDtypeStruct(r.shape, np.dtype(np.int32))
+          for r in semaphore_ref_avals
+      ],
       new_jaxpr,
       compiler_params,
       new_consts,
@@ -668,6 +694,7 @@ def lower_jaxpr_to_module(
     cluster: Sequence[int],
     in_shapes: Sequence[jax.ShapeDtypeStruct],
     out_shapes: Sequence[jax.ShapeDtypeStruct],
+    gmem_scratch_shapes: Sequence[jax.ShapeDtypeStruct],
     jaxpr: jax_core.Jaxpr,
     compiler_params: dict[str, Any],
     consts=(),
@@ -754,14 +781,14 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
     # Each range is 2 events, each event is 4 bytes.
     prof_spec = mgpu_profiler.ProfilerSpec(prof_space * 2 * 4)
     prof_ctx = ProfilerContext(params["profile_dir"], prof_spec)
-  module, out_structs_gmem, _, launch_ctx, scratch_arr = (
+  module, new_out_shapes, _, launch_ctx, scratch_arr = (
       mgpu_core._lower_as_gpu_kernel(
           body,
          grid=tuple(map(operator.mul, parallel_grid, cluster)),
          cluster=cluster,
          block=block,
          in_shapes=in_shapes,
-         out_shape=out_shapes,
+         out_shape=(*out_shapes, *gmem_scratch_shapes),
          smem_scratch_shape=scratch_buffers,
          module_name=mlir.sanitize_name(debug_info.func_name),
          prof_spec=prof_spec,
@@ -777,8 +804,11 @@ def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value):
 
   mgpu_core._initialize_scratch(launch_ctx, scratch_arr)
 
+  if gmem_scratch_shapes:
+    new_out_shapes = new_out_shapes[:-len(gmem_scratch_shapes)]
+
   return LoweringResult(
-      module, parallel_grid, block, out_structs_gmem, prof_ctx
+      module, parallel_grid, block, new_out_shapes, prof_ctx, tuple(gmem_scratch_shapes)
   )
 
 
jax/_src/pallas/mosaic_gpu/pallas_call_registration.py

Lines changed: 26 additions & 6 deletions
@@ -23,11 +23,13 @@
 import warnings
 
 import jax
+from jax import lax
 from jax._src import core as jax_core
 from jax._src.interpreters import mlir
 from jax._src.pallas import core as pallas_core
 from jax._src.pallas.mosaic_gpu import lowering
 from jax.experimental.mosaic import gpu as mgpu
+import numpy as np
 
 
 def pallas_call_lowering(
@@ -74,16 +76,30 @@ def pallas_call_lowering(
     print(lowering_result.module.operation)
 
   module = lowering_result.module
-  new_avals_out = [
-      jax_core.ShapedArray(t.shape, t.dtype) for t in lowering_result.out_structs
-  ]
+  new_avals_in = list(ctx.avals_in)
+  new_avals_out = list(map(_as_shaped_array, lowering_result.new_out_shapes))
+  scratch_args = ()
+  if lowering_result.gmem_scratch_shapes:
+    input_output_aliases += tuple(
+        (len(new_avals_in) + i, len(new_avals_out) + i)
+        for i in range(len(lowering_result.gmem_scratch_shapes))
+    )
+    new_avals_in.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes))
+    new_avals_out.extend(map(_as_shaped_array, lowering_result.gmem_scratch_shapes))
+    def zero_init_gmem_scratch():
+      return [lax.zeros_like_array(s) for s in lowering_result.gmem_scratch_shapes]
+    scratch_args = mlir.lower_fun(
+        zero_init_gmem_scratch, multiple_results=True
+    )(ctx.replace(avals_in=()))
   outs = mgpu.core._mosaic_gpu_lowering_rule(
-      ctx.replace(avals_out=new_avals_out),
-      *args,
+      ctx.replace(avals_in=new_avals_in, avals_out=new_avals_out),
+      *args, *scratch_args,
       module=module,
-      out_types=lowering_result.out_structs,
+      out_types=(*lowering_result.new_out_shapes, *lowering_result.gmem_scratch_shapes),
      input_output_aliases=input_output_aliases,
   )
+  if lowering_result.gmem_scratch_shapes:  # Drop the GMEM scratch.
+    outs = outs[:-len(lowering_result.gmem_scratch_shapes)]
   if (prof_ctx := lowering_result.profiler_context) is not None:
     *outs, prof_buffer = outs
     if (dump_path := prof_ctx.dump_path) == "sponge":
@@ -112,3 +128,7 @@ def do_callback(prof_buffer):
         ctx.replace(avals_in=(new_avals_out[-1],)), prof_buffer
     )
   return outs
+
+
+def _as_shaped_array(t: jax.ShapeDtypeStruct) -> jax_core.ShapedArray:
+  return jax_core.ShapedArray(t.shape, np.dtype(t.dtype))
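
A small illustration of the aliasing bookkeeping introduced above (the variable names below are mine, not from the diff): each GMEM scratch buffer is appended as an extra zero-filled operand and as an extra output, and the (operand, output) index pair is recorded in input_output_aliases so XLA hands the zeroed buffer to the kernel in place; the extra outputs are then dropped. With the shapes used by the new test (two inputs, one output, one semaphore array):

# Hypothetical counts matching the new test: 2 kernel inputs, 1 output,
# 1 GMEM scratch (semaphore) buffer.
num_inputs, num_outputs, num_scratch = 2, 1, 1
input_output_aliases = tuple(
    (num_inputs + i, num_outputs + i) for i in range(num_scratch)
)
# Operand 2 (the zero-filled scratch) aliases output 1, which is what the
# test's expected "operand_index = 2 ... output_tuple_indices = [1]" checks.
assert input_output_aliases == ((2, 1),)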

jax/experimental/mosaic/gpu/core.py

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ def _mosaic_gpu_lowering_rule(
     out_types,
     input_output_aliases: tuple[tuple[int, int], ...] = (),
 ):
+  assert len(args) == len(ctx.avals_in)
   assert len(out_types) == len(ctx.avals_out)
   module = _run_serde_pass(
       module,

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
 from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace as GPUMemorySpace
 from jax._src.pallas.mosaic_gpu.core import GPUMesh as GPUMesh
 from jax._src.pallas.mosaic_gpu.core import kernel as kernel
+from jax._src.pallas.mosaic_gpu.core import SemaphoreType as SemaphoreType
 from jax._src.pallas.mosaic_gpu.core import SwizzleTransform as SwizzleTransform
 from jax._src.pallas.mosaic_gpu.core import TilingTransform as TilingTransform
 from jax._src.pallas.mosaic_gpu.core import transpose_ref as transpose_ref

tests/pallas/mosaic_gpu_test.py

Lines changed: 24 additions & 0 deletions
@@ -2408,6 +2408,30 @@ def compute(l_smem, r_smem, o_smem):
     out = plgpu.kernel(body, out_shape=x, grid=(2,), axis_names=("rows",))(x, x)
     np.testing.assert_allclose(out, x + x)
 
+  def test_semaphore_lowering(self):
+    # This is a smoke test until we add support for lowering of semaphore ops.
+    def body(i_ref1, i_ref2, o_ref, sem_ref):
+      del i_ref2  # Only here to have a different number of inputs and outputs.
+      assert sem_ref.shape == (4,)
+      assert jnp.issubdtype(sem_ref.dtype, pl.semaphore)
+      o_ref[...] = i_ref1[...]
+    x = jnp.arange(128, dtype=jnp.float32).reshape((128,))
+    kernel = pl.pallas_call(
+        body, out_shape=x, scratch_shapes=[plgpu.SemaphoreType.REGULAR((4,))],
+    )
+    text = jax.jit(kernel).lower(x, x).as_text()
+    self.assertIn(
+        r"output_operand_aliases ="
+        r" [#stablehlo.output_operand_alias<output_tuple_indices = [1],"
+        r" operand_index = 2, operand_tuple_indices = []>]",
+        text,
+    )
+    self.assertIn(
+        r"(tensor<128xf32>, tensor<128xf32>, tensor<4xi32>) ->"
+        r" (tensor<128xf32>, tensor<4xi32>)",
+        text,
+    )
+
 
 class ExamplesSm90ATest(PallasSm90ATest):
 