Skip to content

Commit 757c92d

Browse files
apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Expose async SMEM->TMEM copies for arbitrary data
PiperOrigin-RevId: 872782950
1 parent 4a28ce1 commit 757c92d

File tree

3 files changed

+171
-1
lines changed

3 files changed

+171
-1
lines changed

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3483,7 +3483,7 @@ def _async_copy_to_tmem_abstract_eval(smem_ref, tmem_ref, *_args, **_kwargs):
34833483
raise ValueError("async_copy_scales_to_tmem source must be an SMEM ref")
34843484
if tmem_ref.memory_space != gpu_core.MemorySpace.TMEM:
34853485
raise ValueError("async_copy_scales_to_tmem target must be a TMEM ref")
3486-
return (), {gpu_core._memory_effect}
3486+
return (), {state_types.ReadEffect(0), state_types.WriteEffect(1)}
34873487

34883488
def _async_copy_to_tmem_lowering_rule(
34893489
impl, ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves, smem_tree, tmem_tree, collective_axis
@@ -3552,6 +3552,145 @@ def _async_copy_sparse_metadata_to_tmem_lowering_rule(*args, **kwargs):
35523552
)
35533553

35543554

3555+
async_copy_smem_to_tmem_p = jax_core.Primitive("async_copy_smem_to_tmem")
async_copy_smem_to_tmem_p.multiple_results = True


def async_copy_smem_to_tmem(
    smem_ref: _Ref,
    tmem_ref: _Ref,
    collective_axis: AxisName | None = None,
):
  """Starts an asynchronous SMEM->TMEM copy (lowered to tcgen05.cp).

  The SMEM source is expected to carry tiling and swizzle transforms, and the
  TMEM destination must use the packed layout (``packed=True`` for sub-32b
  element types).

  Completion can be awaited via ``tcgen05_commit_arrive`` plus a wait on the
  given barrier; no explicit synchronization is needed when the destination
  is consumed by a tcgen05_mma operation.

  Args:
    smem_ref: The SMEM reference to copy from.
    tmem_ref: The TMEM reference to copy into.
    collective_axis: The name of the cluster axis along which the
      copy should be performed collectively. The cluster axis should have a
      size of exactly 2, and must be on the minormost cluster axis.
  """
  # Peel any user-applied indexing/transform stack off both refs so the
  # primitive receives plain refs plus flattened transform leaves.
  src_ref, src_transforms = state_primitives.get_ref_and_transforms(
      smem_ref, None, "async_copy_smem_to_tmem"
  )
  dst_ref, dst_transforms = state_primitives.get_ref_and_transforms(
      tmem_ref, None, "async_copy_smem_to_tmem"
  )
  src_leaves, src_treedef = tree_util.tree_flatten(src_transforms)
  dst_leaves, dst_treedef = tree_util.tree_flatten(dst_transforms)
  # The treedefs let the abstract-eval/lowering rules split the flat leaf
  # list back into per-ref transform trees.
  async_copy_smem_to_tmem_p.bind(
      src_ref, dst_ref, *src_leaves, *dst_leaves,
      smem_tree=src_treedef, tmem_tree=dst_treedef,
      collective_axis=collective_axis,
  )
3599+
3600+
3601+
@async_copy_smem_to_tmem_p.def_effectful_abstract_eval
def _async_copy_smem_to_tmem_abstract_eval(
    smem_ref, tmem_ref, *args, smem_tree, **_kwargs
):
  """Validates operand memory spaces/types/shapes; declares read+write effects."""
  # Memory-space checks come first so their errors take precedence.
  if smem_ref.memory_space != gpu_core.MemorySpace.SMEM:
    raise ValueError("async_copy_smem_to_tmem source must be an SMEM ref")
  if tmem_ref.memory_space != gpu_core.MemorySpace.TMEM:
    raise ValueError("async_copy_smem_to_tmem target must be a TMEM ref")
  # Replay the SMEM transform stack on the aval so dtype/shape are compared
  # after indexing, not on the raw ref.
  src_transforms = jax.tree.unflatten(smem_tree, args[:smem_tree.num_leaves])
  src_aval = smem_ref
  for transform in src_transforms:
    src_aval = transform.transform_type(src_aval)
  if src_aval.dtype != tmem_ref.dtype:
    raise ValueError(
        f"Expected SMEM element type ({src_aval.dtype}) to equal the TMEM"
        f" element type ({tmem_ref.dtype})"
    )
  if src_aval.shape != tmem_ref.shape:
    raise ValueError(
        f"Expected SMEM reference shape {src_aval.shape} to equal the TMEM"
        f" reference shape {tmem_ref.shape}"
    )
  # No outputs; reads operand 0 (SMEM) and writes operand 1 (TMEM).
  return (), {state_types.ReadEffect(0), state_types.WriteEffect(1)}
3624+
3625+
3626+
@lowering.register_lowering_rule(
    async_copy_smem_to_tmem_p, mgpu.LoweringSemantics.Lane
)
@lowering.register_lowering_rule(
    async_copy_smem_to_tmem_p,
    mgpu.LoweringSemantics.Lane,
    gpu_core.PrimitiveSemantics.Warp,
)
def _async_copy_smem_to_tmem_lowering_rule(
    ctx: lowering.LoweringRuleContext, smem_ref, tmem_ref, *leaves,
    smem_tree, tmem_tree, collective_axis,
):
  # Lowers async_copy_smem_to_tmem_p to a tcgen05 SMEM->TMEM copy, issued by
  # a single lane (guarded by the single-lane predicate below).
  assert isinstance(tmem_ref, tcgen05.TMEMRef)
  # `leaves` holds the flattened SMEM transforms followed by the TMEM ones;
  # split them back apart using the treedefs bound by the user-facing API.
  smem_leaves, tmem_leaves = util.split_list(leaves, [smem_tree.num_leaves])
  smem_transforms = jax.tree.unflatten(smem_tree, smem_leaves)
  tmem_transforms = jax.tree.unflatten(tmem_tree, tmem_leaves)
  smem_aval = ctx.avals_in[0]
  assert isinstance(smem_aval, state_types.AbstractRef)
  tmem_aval = ctx.avals_in[1]
  assert isinstance(tmem_aval, state_types.AbstractRef)
  # avals_in[2:] mirrors `leaves`: transform avals, SMEM first then TMEM.
  transform_avals = util.split_list(
      ctx.avals_in[2:], [smem_tree.num_leaves]
  )
  smem_transform_avals = smem_tree.unflatten(transform_avals[0])
  tmem_transform_avals = tmem_tree.unflatten(transform_avals[1])
  # Apply indexing transforms to get the concrete refs; transposes are left
  # unhandled for SMEM so only unswizzle/untile remain below.
  smem_ref, transformed_smem_aval, smem_transforms = lowering._handle_transforms(
      ctx, smem_aval, smem_ref, smem_transform_avals, smem_transforms,
      handle_transposes=False
  )
  tmem_ref, _, tmem_transforms = lowering._handle_transforms(
      ctx, tmem_aval, tmem_ref, tmem_transform_avals, tmem_transforms
  )
  # The SMEM ref must be exactly swizzled + tiled — anything else (or any
  # extra transform) is unsupported by this lowering.
  match smem_transforms:
    case (
        gpu_core.UnswizzleRef(swizzle),
        gpu_core.UntilingTransform(tiling),
    ):
      pass
    case _:
      raise NotImplementedError(
          f"Unsupported transforms for SMEM ref: {smem_transforms}"
      )
  # Number of elements covered by one swizzle period; the tile must be
  # (8, swizzle_elems) for the swizzle pattern to line up.
  swizzle_elems = 8 * swizzle // dtypes.itemsize_bits(transformed_smem_aval.dtype)
  if tiling != (8, swizzle_elems):
    raise ValueError(
        f"Tiling does not fit swizzle: expected (8, {swizzle_elems}), but got"
        f" {tiling}"
    )
  if tmem_transforms:
    raise NotImplementedError(
        f"Unimplemented transforms for TMEM refs: {tmem_transforms}"
    )

  # Only one lane issues the copy. For collective copies, additionally
  # restrict issuing to the leader block of the cluster pair.
  predicate = ctx.module_ctx.single_lane_predicate
  if collective_axis is not None:
    is_leader_block = _collective_mma_predicate(ctx, collective_axis)
    predicate = arith_dialect.andi(predicate, is_leader_block)
    collective = True
  else:
    collective = False

  with mgpu.when(predicate):
    tcgen05.async_copy_smem_to_tmem(
        smem_ref, tmem_ref, swizzle, collective=collective
    )
  return ()
3692+
3693+
35553694
semaphore_signal_parallel_p = jax_core.Primitive('semaphore_signal_parallel')
35563695
semaphore_signal_parallel_p.multiple_results = True
35573696

jax/experimental/pallas/mosaic_gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from jax._src.pallas.mosaic_gpu.pipeline import emit_pipeline_warp_specialized as emit_pipeline_warp_specialized
5454
from jax._src.pallas.mosaic_gpu.pipeline import PipelinePipeline as PipelinePipeline
5555
from jax._src.pallas.mosaic_gpu.primitives import async_copy_scales_to_tmem as async_copy_scales_to_tmem
56+
from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_tmem as async_copy_smem_to_tmem
5657
from jax._src.pallas.mosaic_gpu.primitives import async_copy_sparse_metadata_to_tmem as async_copy_sparse_metadata_to_tmem
5758
from jax._src.pallas.mosaic_gpu.primitives import async_load_tmem as async_load_tmem
5859
from jax._src.pallas.mosaic_gpu.primitives import async_prefetch as async_prefetch

tests/pallas/mosaic_gpu_test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3078,6 +3078,7 @@ def test_missing_primitive_lowerings_are_tracked(self):
30783078
wg_wg_lowered_primitives)
30793079
expected_missing_primitives = {
30803080
mgpu_primitives.async_copy_scales_to_tmem_p,
3081+
mgpu_primitives.async_copy_smem_to_tmem_p,
30813082
mgpu_primitives.async_copy_sparse_metadata_to_tmem_p,
30823083
mgpu_primitives.wait_load_tmem_p,
30833084
mgpu_primitives.semaphore_signal_parallel_p,
@@ -4327,6 +4328,35 @@ def kernel(a_gmem, b_gmem, out_gmem,
43274328
result = f(x, y)
43284329
np.testing.assert_allclose(result, x @ y, rtol=1e-3)
43294330

4331+
def test_async_copy_smem_to_tmem(self):
  """Round-trips data GMEM -> SMEM -> TMEM -> GMEM and checks it is unchanged."""
  self.skip_if_wg_semantics()
  dtype = jnp.float16
  swizzle = 128
  m, n = 128, 128
  transforms = self.default_transforms(swizzle=swizzle, dtype=dtype)

  def kernel(x_gmem, y_gmem, smem, tma_barrier, mma_barrier, tmem):
    # Stage the input into SMEM and wait for the TMA to complete.
    plgpu.copy_gmem_to_smem(x_gmem, smem, tma_barrier)
    plgpu.barrier_wait(tma_barrier)
    # Issue the async SMEM->TMEM copy under test and await it via the
    # tensor-core commit barrier.
    plgpu.async_copy_smem_to_tmem(smem, tmem)
    plgpu.tcgen05_commit_arrive(mma_barrier)
    plgpu.barrier_wait(mma_barrier)
    # Read the copy result back out of TMEM.
    y_gmem[...] = plgpu.async_load_tmem(tmem)

  run = self.kernel(
      kernel,
      out_shape=jax.ShapeDtypeStruct((m, n), dtype),
      scratch_shapes=[
          plgpu.SMEM((m, n), dtype, transforms=transforms),
          plgpu.Barrier(),
          plgpu.Barrier(orders_tensor_core=True),
          plgpu.TMEM((m, n), dtype, packed=True),
      ],
  )
  operand = jax.random.uniform(jax.random.key(0), shape=(m, n), dtype=dtype)
  result = jax.block_until_ready(run(operand))
  # A pure copy must reproduce the input exactly.
  np.testing.assert_array_equal(result, operand)
4359+
43304360
def test_matmul_with_sliced_accumulator(self):
43314361
dtype = jnp.bfloat16
43324362
shape = (128, 128)

0 commit comments

Comments
 (0)