diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py
index 996d4ba05893..5ae3761ec9fd 100644
--- a/jax/_src/pallas/mosaic_gpu/lowering.py
+++ b/jax/_src/pallas/mosaic_gpu/lowering.py
@@ -158,11 +158,18 @@ def __add__(self, other: Resources) -> Resources:
         scoped_gmem_semaphores=scoped_gmem_semaphores,
     )
 
-  def __or__(self, other: Resources) -> Resources:
+  def or_(self, other: Resources, axis_names: _AxisNames) -> Resources:
     sems = self.scoped_gmem_semaphores
     other_sems = other.scoped_gmem_semaphores
-    scoped_gmem_semaphores = {key: max(sems.get(key, 0), other_sems.get(key, 0))
-                              for key in sems.keys() | other_sems.keys()}
+    scoped_gmem_semaphores = {}
+    for sem_scope in sems.keys() | other_sems.keys():
+      if _is_block_local_scope(sem_scope, axis_names):
+        value = max(sems.get(sem_scope, 0), other_sems.get(sem_scope, 0))
+      elif _is_global_scope(sem_scope, axis_names):
+        value = sems.get(sem_scope, 0) + other_sems.get(sem_scope, 0)
+      else:
+        raise RuntimeError(f"Unrecognized semaphore scope: {sem_scope}")
+      scoped_gmem_semaphores[sem_scope] = value
     return Resources(
         smem_scratch_bytes=max(
             self.smem_scratch_bytes, other.smem_scratch_bytes
@@ -204,7 +211,10 @@ def _estimate_resources(
   for eqn in jaxpr.eqns:
     # TODO(slebedev): Add support for other primitives, notably control flow.
     if rule := _resource_estimators.get(eqn.primitive):
-      rs |= rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params)
+      rs = rs.or_(
+          rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params),
+          ctx.axis_names,
+      )
       continue
     # Assume that unsupported primitives are neutral wrt resource usage,
     # unless they have a jaxpr in their params.
@@ -225,7 +235,7 @@ def _cond_resource_estimator(
 ) -> Resources:
   del args  # Unused.
   return functools.reduce(
-      lambda a, b: a | b,
+      lambda a, b: a.or_(b, ctx.axis_names),
       (_estimate_resources(ctx, branch.jaxpr) for branch in branches),
   )
 
@@ -247,8 +257,8 @@ def _while_resource_estimator(
     **params,
 ) -> Resources:
   del args, params  # Unused.
-  return _estimate_resources(ctx, cond_jaxpr.jaxpr) | _estimate_resources(
-      ctx, body_jaxpr.jaxpr
+  return _estimate_resources(ctx, cond_jaxpr.jaxpr).or_(
+      _estimate_resources(ctx, body_jaxpr.jaxpr), ctx.axis_names
   )
 
 
@@ -338,7 +348,13 @@ def _run_scoped_resource_estimator(
       # Don't need to allocate anything.
      pass
     elif aval.memory_space == gpu_core.GMEM and jnp.issubdtype(aval.dtype, pallas_core.semaphore):
-      rs += Resources(scoped_gmem_semaphores={collective_axes: aval.size})
+      if _is_block_local_scope(collective_axes, ctx.axis_names):
+        rs += Resources(scoped_gmem_semaphores={collective_axes: aval.size})
+      else:
+        raise ValueError(
+            "Only thread-collective allocations are supported in run_scoped. To"
+            " allocate global semaphores, use pl.get_global."
+        )
     else:
       raise NotImplementedError(
           f"Unsupported memory space: {aval.memory_space}")
@@ -2567,6 +2583,10 @@ def _run_scoped_lowering_rule(
   # Make sure everyone has exited previous scoped allocations. Note that we
   # don't synchronize when we exit the allocation, but only when we might want
   # to reuse its memory again.
+  if collective_axes and collective_axes != (wg_axis,):
+    raise ValueError(
+        "Only thread-collective allocations are supported in run_scoped."
+    )
   if is_multithreaded and is_thread_collective:
     gpu_dialect.barrier()
   with contextlib.ExitStack() as alloc_stack:
@@ -2705,6 +2725,33 @@ def _run_scoped_lowering_rule(
   return outs
 
 
+@_register_resource_estimator(primitives.get_global_p)
+def _get_global_resource_estimator(
+    ctx: ResourceEstimatorContext, *, what
+) -> Resources:
+  if what.memory_space == gpu_core.GMEM and jnp.issubdtype(
+      what.dtype, pallas_core.semaphore
+  ):
+    collective_axes = tuple(ctx.axis_names)
+    return Resources(scoped_gmem_semaphores={collective_axes: what.size})
+  raise NotImplementedError(f"get_global only supports semaphores, got {what}")
+
+
+@register_lowering_rule(primitives.get_global_p, mgpu.LoweringSemantics.Lane)
+@register_lowering_rule(
+    primitives.get_global_p, mgpu.LoweringSemantics.Warpgroup
+)
+def _get_global_lowering_rule(ctx: LoweringRuleContext, *, what):
+  if what.memory_space == gpu_core.GMEM and jnp.issubdtype(
+      what.dtype, pallas_core.semaphore
+  ):
+    collective_axes = tuple(ctx.module_ctx.axis_names)
+    return ctx.module_ctx.reserve_semaphores(
+        what.shape, collective_axes=collective_axes
+    ).__enter__()
+  raise NotImplementedError(f"get_global only supports semaphores, got {what}")
+
+
 @register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Lane)
 @register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Warpgroup)
 def _run_state_lowering_rule(
@@ -3386,9 +3433,14 @@ def _semaphore_read_lowering_rule(ctx: LoweringRuleContext, *args, args_tree):
     raise NotImplementedError(f"Unhandled transforms for semaphore_read: {transforms}")
   sem_ptr = mgpu.utils.memref_ptr(sem)
   i32_ty = ir.IntegerType.get_signless(32)
-  return llvm_dialect.inline_asm(
-      i32_ty, [sem_ptr], "ld.acquire.sys.u32 $0,[$1];", "=r,l", has_side_effects=True,
+  result = llvm_dialect.inline_asm(
+      i32_ty,
+      [sem_ptr],
+      "ld.acquire.sys.u32 $0,[$1];",
+      "=r,l",
+      has_side_effects=True,
   )
+  return _ensure_fa(result, jnp.int32)
 
 
 @register_lowering_rule(primitives.semaphore_signal_p, mgpu.LoweringSemantics.Lane)
diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py
index e9aa7be6c008..129b22cd7429 100644
--- a/jax/_src/pallas/primitives.py
+++ b/jax/_src/pallas/primitives.py
@@ -966,6 +966,47 @@ def _lower_fun(*lower_fun_args):
   return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)
 
 
+get_global_p = jax_core.Primitive("get_global")
+get_global_p.multiple_results = False
+
+
+def get_global(what: pallas_core.ScratchShape) -> jax.Array:
+  """Returns a global reference that persists across all kernel invocations.
+
+  Each call to get_global returns a different and unique reference, but one that
+  is stable across invocations of the kernel body.
+
+  Args:
+    what: The reference type to allocate. Each backend has its own set of
+      reference types (e.g., `plgpu.SemaphoreType.REGULAR` for GPU).
+
+  Example::
+
+    sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR)
+    pl.semaphore_signal(sem_ref)
+    pl.semaphore_wait(sem_ref)
+  """
+  ref_aval = what.get_ref_aval()
+  return get_global_p.bind(what=ref_aval)
+
+
+@get_global_p.def_abstract_eval
+def _get_global_abstract_eval(*, what):
+  return what
+
+
+def _get_global_discharge_rule(in_avals, out_avals, *, what):
+  del in_avals, out_avals, what
+  raise NotImplementedError(
+      "get_global discharge is not supported in interpret mode."
+  )
+
+
+state_discharge.register_discharge_rule(get_global_p)(
+    _get_global_discharge_rule
+)
+
+
 def _get_ref_and_transforms(ref):
   if isinstance(ref, state.TransformedRef):
     return ref.ref, ref.transforms
diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py
index b8db4e430ec7..76d2807a2999 100644
--- a/jax/experimental/pallas/__init__.py
+++ b/jax/experimental/pallas/__init__.py
@@ -29,8 +29,8 @@
 from jax._src.pallas.core import Element as Element
 from jax._src.pallas.core import GridSpec as GridSpec
 from jax._src.pallas.core import lower_as_mlir as lower_as_mlir
-from jax._src.pallas.core import MemorySpace as MemorySpace
 from jax._src.pallas.core import MemoryRef as MemoryRef
+from jax._src.pallas.core import MemorySpace as MemorySpace
 from jax._src.pallas.core import no_block_spec as no_block_spec
 from jax._src.pallas.core import semaphore as semaphore
 from jax._src.pallas.core import Squeezed as Squeezed
@@ -57,6 +57,7 @@
 from jax._src.pallas.primitives import debug_print as debug_print
 from jax._src.pallas.primitives import DeviceIdType as DeviceIdType
 from jax._src.pallas.primitives import dot as dot
+from jax._src.pallas.primitives import get_global as get_global
 from jax._src.pallas.primitives import load as _deprecated_load
 from jax._src.pallas.primitives import max_contiguous as max_contiguous
 from jax._src.pallas.primitives import multiple_of as multiple_of
diff --git a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py
index d810fa8512e1..a04862541b66 100644
--- a/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py
@@ -74,7 +74,8 @@ def all_gather_lhs_matmul(
 
   num_sms = jax.devices()[0].core_count  # 132 for H100 SXM GPUs.
 
-  def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref, out_smem, received_sem):
+  def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref, out_smem):
+    received_sem = pl.get_global(plgpu.SemaphoreType.REGULAR)
     wg_idx = lax.axis_index("wg")
     dev_id = lax.axis_index(axis_name)
     send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size)
@@ -145,13 +146,6 @@ def _device_loop(device_offset):
     # Make sure all copies are fully done.
     plgpu.wait_smem_to_gmem(0, wait_read_only=True)
 
-  def kernel_entry(*args):
-    return pl.run_scoped(
-        functools.partial(kernel_body, *args),
-        received_sem=plgpu.SemaphoreType.REGULAR,
-        collective_axes=("cluster_grid", "cluster", "wg"),
-    )
-
   num_out_slots = min(2, (tile_m * tile_n) // (epi_tile_m * epi_tile_n))
   out_swizzle = plgpu.find_swizzle(epi_tile_n * jnp.dtype(dtype).itemsize * 8)
   out_swizzle_elems = out_swizzle // jnp.dtype(dtype).itemsize
@@ -160,7 +154,7 @@ def kernel_entry(*args):
       plgpu.SwizzleTransform(out_swizzle),
   )
   result, _ = plgpu.kernel(
-      kernel_entry,
+      kernel_body,
       out_shape=[
           # The output, with its M dimension all-gathered.
           jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype),
diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py
index 10286ff4e17c..405480abcb80 100644
--- a/tests/pallas/mosaic_gpu_test.py
+++ b/tests/pallas/mosaic_gpu_test.py
@@ -5667,18 +5667,16 @@ def test_global_semaphore(self):
     # We signal from block 0 and wait on block 1 to test whether the semaphore
     # is globally shared.
     def body(out_ref):
-      @functools.partial(pl.run_scoped,
-                         sem_ref=plgpu.SemaphoreType.REGULAR,
-                         collective_axes="x")
-      def _scoped(sem_ref):
-        block_id = lax.axis_index("x")
-        @pl.when(block_id == 0)
-        def _():
-          pl.semaphore_signal(sem_ref)
-        @pl.when(block_id == 1)
-        def _():
-          pl.semaphore_wait(sem_ref)
-        out_ref[...] = jnp.ones_like(out_ref)
+      sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR)
+      block_id = lax.axis_index("x")
+      @pl.when(block_id == 0)
+      def _():
+        pl.semaphore_signal(sem_ref)
+      @pl.when(block_id == 1)
+      def _():
+        pl.semaphore_wait(sem_ref)
+      out_ref[...] = jnp.ones_like(out_ref)
+
     kernel = self.kernel(
         body,
         out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
@@ -5690,18 +5688,16 @@ def _():
 
   def test_global_semaphore_with_multiple_threads(self):
     def body(out_ref):
-      @functools.partial(pl.run_scoped,
-                         sem_ref=plgpu.SemaphoreType.REGULAR,
-                         collective_axes=("x", "wg"))
-      def _scoped(sem_ref):
-        block_id = lax.axis_index("x")
-        @pl.when(block_id == 0)
-        def _():
-          pl.semaphore_signal(sem_ref)
-        @pl.when(block_id == 1)
-        def _():
-          pl.semaphore_wait(sem_ref)
-        out_ref[...] = jnp.ones_like(out_ref)
+      sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR)
+      block_id = lax.axis_index("x")
+      @pl.when(block_id == 0)
+      def _():
+        pl.semaphore_signal(sem_ref)
+      @pl.when(block_id == 1)
+      def _():
+        pl.semaphore_wait(sem_ref)
+      out_ref[...] = jnp.ones_like(out_ref)
+
     kernel = self.kernel(
         body,
         out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
@@ -5713,27 +5709,82 @@ def _():
     result = kernel()
     np.testing.assert_array_equal(result, jnp.ones((128,), jnp.float32))
 
+  def test_multiple_get_global_semaphores(self):
+    def body(out_ref):
+      sem1 = pl.get_global(plgpu.SemaphoreType.REGULAR)
+      sem2 = pl.get_global(plgpu.SemaphoreType.REGULAR)
+      block_id = lax.axis_index("x")
+      @pl.when(block_id == 0)
+      def _():
+        pl.semaphore_signal(sem1, inc=5)
+        pl.semaphore_signal(sem2, inc=10)
+      @pl.when(block_id == 1)
+      def _():
+        pl.semaphore_wait(sem1, value=5, decrement=False)
+        pl.semaphore_wait(sem2, value=10, decrement=False)
+      val1 = pl.semaphore_read(sem1)
+      val2 = pl.semaphore_read(sem2)
+      out_ref[0] = val1
+      out_ref[1] = val2
+
+    kernel = self.kernel(
+        body,
+        out_shape=jax.ShapeDtypeStruct((2,), jnp.int32),
+        grid=(10,),
+        grid_names=("x",),
+    )
+    result = kernel()
+    np.testing.assert_array_equal(result, jnp.array([5, 10], jnp.int32))
+
+  def test_get_global_in_and_outside_control_flow(self):
+    def body(out_ref):
+      sem_before = pl.get_global(plgpu.SemaphoreType.REGULAR)
+      block_id = lax.axis_index("x")
+
+      @pl.when(block_id == 0)
+      def _():
+        sem_inside = pl.get_global(plgpu.SemaphoreType.REGULAR)
+        pl.semaphore_signal(sem_inside, 7)
+        pl.semaphore_signal(sem_before, 3)
+        val_inside = pl.semaphore_read(sem_inside)
+        out_ref[1] = val_inside
+
+      sem_after = pl.get_global(plgpu.SemaphoreType.REGULAR)
+      pl.semaphore_signal(sem_after, 11)
+      val_before = pl.semaphore_read(sem_before)
+      val_after = pl.semaphore_read(sem_after)
+      out_ref[0] = val_before
+      out_ref[2] = val_after
+
+    kernel = self.kernel(
+        body,
+        out_shape=jax.ShapeDtypeStruct((3,), jnp.int32),
+        grid=(1,),
+        grid_names=("x",),
+    )
+    result = kernel()
+    np.testing.assert_array_equal(result, jnp.array([3, 7, 11], jnp.int32))
+
   def test_multiple_semaphore_scopes(self):
     def body(out_ref):
-      # Allocate a global-scoped semaphore.
-      @functools.partial(pl.run_scoped,
-                         global_sem=plgpu.SemaphoreType.REGULAR,
-                         collective_axes="x")
-      def _scope1(global_sem):
-        # Allocate a block-scoped semaphore.
-        @functools.partial(pl.run_scoped,
-                           block_sem=plgpu.SemaphoreType.REGULAR)
-        def _scope2(block_sem):
-          block_id = lax.axis_index("x")
-          pl.semaphore_signal(block_sem)
-          @pl.when(block_id == 0)
-          def _():
-            pl.semaphore_signal(global_sem)
-          @pl.when(block_id == 1)
-          def _():
-            pl.semaphore_wait(global_sem)
-            out_ref[...] = jnp.ones_like(out_ref)
-          pl.semaphore_wait(block_sem)
+      global_sem = pl.get_global(plgpu.SemaphoreType.REGULAR)
+
+      @functools.partial(pl.run_scoped, block_sem=plgpu.SemaphoreType.REGULAR)
+      def _scope2(block_sem):
+        block_id = lax.axis_index("x")
+        pl.semaphore_signal(block_sem)
+
+        @pl.when(block_id == 0)
+        def _():
+          pl.semaphore_signal(global_sem)
+
+        @pl.when(block_id == 1)
+        def _():
+          pl.semaphore_wait(global_sem)
+          out_ref[...] = jnp.ones_like(out_ref)
+
+        pl.semaphore_wait(block_sem)
+
     kernel = self.kernel(
         body,
         out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
@@ -5950,22 +6001,22 @@ def test_dynamic_work_scheduling_with_carry(self):
     sm_count = jax.devices()[0].core_count
     def body(out_gmem, _):
       sm_idx = lax.axis_index("x")
-
-      @functools.partial(pl.run_scoped,
-                         global_semaphore=plgpu.SemaphoreType.REGULAR,
-                         collective_axes="x")
-      def _scoped(global_semaphore):
-        @pl.when(sm_idx == 0)
-        def _steal_loop():
-          def loop_body(loop_info: plgpu.NDLoopInfo, carry: jax.Array):
-            del loop_info
-            return carry + jnp.int32(1)
-          final_carry = plgpu.dynamic_scheduling_loop(
-              ("x",), init_carry=jnp.int32(0))(loop_body)
-          out_gmem[0] = final_carry
-          pl.semaphore_signal(global_semaphore, inc=sm_count)
-        # All SMs wait until SM 0 has finished all blocks.
-        pl.semaphore_wait(global_semaphore)
+      global_semaphore = pl.get_global(plgpu.SemaphoreType.REGULAR)
+
+      @pl.when(sm_idx == 0)
+      def _steal_loop():
+        def loop_body(loop_info: plgpu.NDLoopInfo, carry: jax.Array):
+          del loop_info
+          return carry + jnp.int32(1)
+
+        final_carry = plgpu.dynamic_scheduling_loop(
+            ("x",), init_carry=jnp.int32(0)
+        )(loop_body)
+        out_gmem[0] = final_carry
+        pl.semaphore_signal(global_semaphore, inc=sm_count)
+
+      # All SMs wait until SM 0 has finished all blocks.
+      pl.semaphore_wait(global_semaphore)
 
     result = self.kernel(body,
                          out_shape=jax.ShapeDtypeStruct((1,), jnp.int32),
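
Note for reviewers: below is a minimal usage sketch of the new `pl.get_global` API, adapted from the `test_global_semaphore` test in this patch. It is not part of the diff; the `grid`, `grid_names`, and `out_shape` values are illustrative, and `plgpu.kernel` stands in for the test harness's `self.kernel` wrapper.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def body(out_ref):
  # Each pl.get_global call returns a distinct GMEM semaphore shared by every
  # block in the grid, unlike run_scoped allocations, which this patch
  # restricts to thread-collective (block-local) scope.
  sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR)
  block_id = lax.axis_index("x")

  @pl.when(block_id == 0)
  def _():
    pl.semaphore_signal(sem_ref)

  @pl.when(block_id == 1)
  def _():
    pl.semaphore_wait(sem_ref)

  out_ref[...] = jnp.ones_like(out_ref)


kernel = plgpu.kernel(
    body,
    out_shape=jax.ShapeDtypeStruct((128,), jnp.float32),
    grid=(2,),
    grid_names=("x",),
)
result = kernel()  # Expected: jnp.ones((128,), jnp.float32)
```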