jax/_src/pallas/mosaic_gpu/lowering.py (72 changes: 62 additions & 10 deletions)
@@ -158,11 +158,18 @@ def __add__(self, other: Resources) -> Resources:
scoped_gmem_semaphores=scoped_gmem_semaphores,
)

def __or__(self, other: Resources) -> Resources:
def or_(self, other: Resources, axis_names: _AxisNames) -> Resources:
sems = self.scoped_gmem_semaphores
other_sems = other.scoped_gmem_semaphores
scoped_gmem_semaphores = {key: max(sems.get(key, 0), other_sems.get(key, 0))
for key in sems.keys() | other_sems.keys()}
scoped_gmem_semaphores = {}
for sem_scope in sems.keys() | other_sems.keys():
if _is_block_local_scope(sem_scope, axis_names):
value = max(sems.get(sem_scope, 0), other_sems.get(sem_scope, 0))
elif _is_global_scope(sem_scope, axis_names):
value = sems.get(sem_scope, 0) + other_sems.get(sem_scope, 0)
else:
raise RuntimeError(f"Unrecognized semaphore scope: {sem_scope}")
scoped_gmem_semaphores[sem_scope] = value
return Resources(
smem_scratch_bytes=max(
self.smem_scratch_bytes, other.smem_scratch_bytes
@@ -204,7 +211,10 @@ def _estimate_resources(
for eqn in jaxpr.eqns:
# TODO(slebedev): Add support for other primitives, notably control flow.
if rule := _resource_estimators.get(eqn.primitive):
rs |= rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params)
rs = rs.or_(
rule(ctx, *(invar.aval for invar in eqn.invars), **eqn.params),
ctx.axis_names,
)
continue
# Assume that unsupported primitives are neutral wrt resource usage,
# unless they have a jaxpr in their params.
@@ -225,7 +235,7 @@ def _cond_resource_estimator(
) -> Resources:
del args # Unused.
return functools.reduce(
lambda a, b: a | b,
lambda a, b: a.or_(b, ctx.axis_names),
(_estimate_resources(ctx, branch.jaxpr) for branch in branches),
)

@@ -247,8 +257,8 @@ def _while_resource_estimator(
**params,
) -> Resources:
del args, params # Unused.
return _estimate_resources(ctx, cond_jaxpr.jaxpr) | _estimate_resources(
ctx, body_jaxpr.jaxpr
return _estimate_resources(ctx, cond_jaxpr.jaxpr).or_(
_estimate_resources(ctx, body_jaxpr.jaxpr), ctx.axis_names
)


@@ -338,7 +348,13 @@ def _run_scoped_resource_estimator(
# Don't need to allocate anything.
pass
elif aval.memory_space == gpu_core.GMEM and jnp.issubdtype(aval.dtype, pallas_core.semaphore):
rs += Resources(scoped_gmem_semaphores={collective_axes: aval.size})
if _is_block_local_scope(collective_axes, ctx.axis_names):
rs += Resources(scoped_gmem_semaphores={collective_axes: aval.size})
else:
raise ValueError(
"Only thread-collective allocations are supported in run_scoped. To"
" allocate global semaphores, use pl.get_global."
)
else:
raise NotImplementedError(
f"Unsupported memory space: {aval.memory_space}")
@@ -2567,6 +2583,10 @@ def _run_scoped_lowering_rule(
# Make sure everyone has exited previous scoped allocations. Note that we
# don't synchronize when we exit the allocation, but only when we might want
# to reuse its memory again.
if collective_axes and collective_axes != (wg_axis,):
raise ValueError(
"Only thread-collective allocations are supported in run_scoped."
)
if is_multithreaded and is_thread_collective:
gpu_dialect.barrier()
with contextlib.ExitStack() as alloc_stack:
@@ -2705,6 +2725,33 @@ def _run_scoped_lowering_rule(
return outs


@_register_resource_estimator(primitives.get_global_p)
def _get_global_resource_estimator(
ctx: ResourceEstimatorContext, *, what
) -> Resources:
if what.memory_space == gpu_core.GMEM and jnp.issubdtype(
what.dtype, pallas_core.semaphore
):
collective_axes = tuple(ctx.axis_names)
return Resources(scoped_gmem_semaphores={collective_axes: what.size})
raise NotImplementedError(f"get_global only supports semaphores, got {what}")


@register_lowering_rule(primitives.get_global_p, mgpu.LoweringSemantics.Lane)
@register_lowering_rule(
primitives.get_global_p, mgpu.LoweringSemantics.Warpgroup
)
def _get_global_lowering_rule(ctx: LoweringRuleContext, *, what):
if what.memory_space == gpu_core.GMEM and jnp.issubdtype(
what.dtype, pallas_core.semaphore
):
collective_axes = tuple(ctx.module_ctx.axis_names)
return ctx.module_ctx.reserve_semaphores(
what.shape, collective_axes=collective_axes
).__enter__()
raise NotImplementedError(f"get_global only supports semaphores, got {what}")


@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Lane)
@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Warpgroup)
def _run_state_lowering_rule(
@@ -3386,9 +3433,14 @@ def _semaphore_read_lowering_rule(ctx: LoweringRuleContext, *args, args_tree):
raise NotImplementedError(f"Unhandled transforms for semaphore_read: {transforms}")
sem_ptr = mgpu.utils.memref_ptr(sem)
i32_ty = ir.IntegerType.get_signless(32)
return llvm_dialect.inline_asm(
i32_ty, [sem_ptr], "ld.acquire.sys.u32 $0,[$1];", "=r,l", has_side_effects=True,
result = llvm_dialect.inline_asm(
i32_ty,
[sem_ptr],
"ld.acquire.sys.u32 $0,[$1];",
"=r,l",
has_side_effects=True,
)
return _ensure_fa(result, jnp.int32)


@register_lowering_rule(primitives.semaphore_signal_p, mgpu.LoweringSemantics.Lane)
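For context (not part of the diff): a minimal, standalone sketch of the combination rule that the new `Resources.or_` in lowering.py above implements. Block-local semaphore scopes take the maximum of the two estimates, while global scopes are summed. The scope test below is an assumption made for illustration; the real predicates are `_is_block_local_scope` and `_is_global_scope`, and the real code also rejects unrecognized scopes.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ScopedSemaphores:
  # Maps a semaphore scope (tuple of collective axis names) to a count.
  counts: dict[tuple[str, ...], int]

  def or_(self, other: "ScopedSemaphores",
          axis_names: tuple[str, ...]) -> "ScopedSemaphores":
    combined = {}
    for scope in self.counts.keys() | other.counts.keys():
      mine, theirs = self.counts.get(scope, 0), other.counts.get(scope, 0)
      if set(scope) == set(axis_names):  # Assumed global scope: sum the counts.
        combined[scope] = mine + theirs
      else:                              # Assumed block-local scope: take the max.
        combined[scope] = max(mine, theirs)
    return ScopedSemaphores(combined)


a = ScopedSemaphores({("wg",): 2, ("grid", "cluster", "wg"): 1})
b = ScopedSemaphores({("wg",): 3, ("grid", "cluster", "wg"): 1})
merged = a.or_(b, axis_names=("grid", "cluster", "wg"))
assert merged.counts == {("wg",): 3, ("grid", "cluster", "wg"): 2}
```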
jax/_src/pallas/primitives.py (41 changes: 41 additions & 0 deletions)
@@ -966,6 +966,47 @@ def _lower_fun(*lower_fun_args):
return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args)


get_global_p = jax_core.Primitive("get_global")
get_global_p.multiple_results = False


def get_global(what: pallas_core.ScratchShape) -> jax.Array:
"""Returns a global reference that persists across all kernel invocations.

Each call to get_global returns a distinct reference, but one that is stable
across invocations of the kernel body.

Args:
what: The reference type to allocate. Each backend has its own set of
reference types (e.g., `plgpu.SemaphoreType.REGULAR` for GPU).

Example::

sem_ref = pl.get_global(plgpu.SemaphoreType.REGULAR)
pl.semaphore_signal(sem_ref)
pl.semaphore_wait(sem_ref)
"""
ref_aval = what.get_ref_aval()
return get_global_p.bind(what=ref_aval)


@get_global_p.def_abstract_eval
def _get_global_abstract_eval(*, what):
return what


def _get_global_discharge_rule(in_avals, out_avals, *, what):
del in_avals, out_avals, what
raise NotImplementedError(
"get_global discharge is not supported in interpret mode."
)


state_discharge.register_discharge_rule(get_global_p)(
_get_global_discharge_rule
)


def _get_ref_and_transforms(ref):
if isinstance(ref, state.TransformedRef):
return ref.ref, ref.transforms
jax/experimental/pallas/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -29,8 +29,8 @@
from jax._src.pallas.core import Element as Element
from jax._src.pallas.core import GridSpec as GridSpec
from jax._src.pallas.core import lower_as_mlir as lower_as_mlir
from jax._src.pallas.core import MemorySpace as MemorySpace
from jax._src.pallas.core import MemoryRef as MemoryRef
from jax._src.pallas.core import MemorySpace as MemorySpace
from jax._src.pallas.core import no_block_spec as no_block_spec
from jax._src.pallas.core import semaphore as semaphore
from jax._src.pallas.core import Squeezed as Squeezed
@@ -57,6 +57,7 @@
from jax._src.pallas.primitives import debug_print as debug_print
from jax._src.pallas.primitives import DeviceIdType as DeviceIdType
from jax._src.pallas.primitives import dot as dot
from jax._src.pallas.primitives import get_global as get_global
from jax._src.pallas.primitives import load as _deprecated_load
from jax._src.pallas.primitives import max_contiguous as max_contiguous
from jax._src.pallas.primitives import multiple_of as multiple_of
jax/experimental/pallas/ops/gpu/collective_matmul_mgpu.py (12 changes: 3 additions & 9 deletions)
@@ -74,7 +74,8 @@ def all_gather_lhs_matmul(

num_sms = jax.devices()[0].core_count # 132 for H100 SXM GPUs.

def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref, out_smem, received_sem):
def kernel_body(lhs_local_ref, rhs_ref, out_ref, scratch_ref, out_smem):
received_sem = pl.get_global(plgpu.SemaphoreType.REGULAR)
wg_idx = lax.axis_index("wg")
dev_id = lax.axis_index(axis_name)
send_dev_id = lax.rem(dev_id + axis_size - 1, axis_size)
@@ -145,13 +146,6 @@ def _device_loop(device_offset):
# Make sure all copies are fully done.
plgpu.wait_smem_to_gmem(0, wait_read_only=True)

def kernel_entry(*args):
return pl.run_scoped(
functools.partial(kernel_body, *args),
received_sem=plgpu.SemaphoreType.REGULAR,
collective_axes=("cluster_grid", "cluster", "wg"),
)

num_out_slots = min(2, (tile_m * tile_n) // (epi_tile_m * epi_tile_n))
out_swizzle = plgpu.find_swizzle(epi_tile_n * jnp.dtype(dtype).itemsize * 8)
out_swizzle_elems = out_swizzle // jnp.dtype(dtype).itemsize
@@ -160,7 +154,7 @@ def kernel_entry(*args):
plgpu.SwizzleTransform(out_swizzle),
)
result, _ = plgpu.kernel(
kernel_entry,
kernel_body,
out_shape=[
# The output, with its M dimension all-gathered.
jax.ShapeDtypeStruct((axis_size * m_shard, n_shard), dtype),
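Illustrative only (not part of the diff): a minimal sketch of the migration this file performs, where a GMEM semaphore previously allocated via `pl.run_scoped(..., collective_axes=...)` is now fetched with `pl.get_global` directly inside the kernel body. The kernel body, shapes, and arguments below are invented for illustration and are not the collective matmul.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def kernel_body(x_ref, o_ref):
  # Global semaphore, stable across invocations of the kernel body; no
  # run_scoped wrapper or extra scratch argument is needed anymore.
  sem = pl.get_global(plgpu.SemaphoreType.REGULAR)
  o_ref[...] = x_ref[...] + 1.0
  pl.semaphore_signal(sem)
  pl.semaphore_wait(sem)


x = jnp.ones((128, 128), jnp.float32)
out = plgpu.kernel(
    kernel_body,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)
```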