Skip to content

Commit f833891

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Pallas:MGPU] Add support for passing in WGMMA lhs from registers
PiperOrigin-RevId: 688117316
1 parent f08801b commit f833891

File tree

2 files changed

+63
-26
lines changed

2 files changed

+63
-26
lines changed

jax/_src/pallas/mosaic_gpu/primitives.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ class _WGMMAPipelineEffect(effects.Effect):
364364

365365
def wgmma(
366366
acc: gpu_core.WGMMAAbstractAccumulatorRef,
367-
a: pallas_core.TransformedRef,
367+
a,
368368
b: pallas_core.TransformedRef,
369369
) -> None:
370370
"""Performs and asynchronous warp group matmul-accumulate on the given references.
@@ -395,12 +395,16 @@ def wgmma(
395395
if a.dtype != b.dtype:
396396
raise ValueError(f"Mixed input dtypes for matrix multiplication unsupported: lhs={a.dtype}, rhs={b.dtype}")
397397

398-
a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms)
398+
if isinstance(a, pallas_core.TransformedRef):
399+
a_transforms_leaves, a_transforms_tree = jax.tree.flatten(a.transforms)
400+
a = a.ref
401+
else:
402+
a_transforms_leaves, a_transforms_tree = [], None
399403
b_transforms_leaves, b_transforms_tree = jax.tree.flatten(b.transforms)
400404

401405
wgmma_ref_p.bind(
402406
acc,
403-
a.ref,
407+
a,
404408
b.ref,
405409
*a_transforms_leaves,
406410
*b_transforms_leaves,
@@ -411,15 +415,15 @@ def wgmma(
411415

412416
@wgmma_ref_p.def_effectful_abstract_eval
413417
def _wgmma_ref_effectful_abstract_eval(acc_aval, a_aval, b_aval, *_, **params):
414-
del a_aval, b_aval, params
418+
del b_aval, params
415419
if not isinstance(acc_aval, gpu_core.WGMMAAbstractAccumulatorRef):
416420
raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc_aval}")
417421
return (), {
418422
_wgmma_pipeline_effect,
419423
state.WriteEffect(0),
420424
state.ReadEffect(0),
421-
state.ReadEffect(1),
422425
state.ReadEffect(2),
426+
*([state.ReadEffect(1)] if isinstance(a_aval, state.AbstractRef) else [])
423427
}
424428

425429

@@ -444,23 +448,31 @@ def _wgmma_lowering(
444448
b_transforms_tree,
445449
):
446450
_, a_aval, *_ = ctx.avals_in
447-
a_transforms_leaves, b_transforms_leaves = util.split_list(
448-
transforms_leaves, [a_transforms_tree.num_leaves]
449-
)
450-
a_transforms = a_transforms_tree.unflatten(a_transforms_leaves)
451-
b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
451+
lhs_swizzle = None
452+
if a_transforms_tree is not None:
453+
a_transforms_leaves, b_transforms_leaves = util.split_list(
454+
transforms_leaves, [a_transforms_tree.num_leaves]
455+
)
456+
a_transforms = a_transforms_tree.unflatten(a_transforms_leaves)
457+
a, a_transforms = lowering._handle_indexing(a, a_transforms)
458+
match a_transforms:
459+
case (gpu_core.UnswizzleRef(lhs_swizzle), gpu_core.UntileRef(tiling)):
460+
swizzle_elems = lhs_swizzle // a_aval.dtype.itemsize
461+
if tiling != (64, swizzle_elems):
462+
raise NotImplementedError("WGMMA lhs tiling does not fit swizzle")
463+
case _:
464+
raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.")
465+
else:
466+
b_transforms_leaves = transforms_leaves # type: ignore
467+
if not isinstance(a, mgpu.FragmentedArray):
468+
raise ValueError(
469+
"When WGMMA lhs is passed in as a ref, it must be transformed by"
470+
" swizzling and tiling appropriately."
471+
)
452472

453-
a, a_transforms = lowering._handle_indexing(a, a_transforms)
473+
b_transforms = b_transforms_tree.unflatten(b_transforms_leaves)
454474
b, b_transforms = lowering._handle_indexing(b, b_transforms)
455475

456-
match a_transforms:
457-
case (gpu_core.UnswizzleRef(swizzle), gpu_core.UntileRef(tiling)):
458-
swizzle_elems = swizzle // a_aval.dtype.itemsize
459-
if tiling != (64, swizzle_elems):
460-
raise NotImplementedError("WGMMA lhs tiling does not fit swizzle")
461-
case _:
462-
raise ValueError(f"WGMMA lhs has unsupported transforms: {a_transforms}.")
463-
464476
match b_transforms:
465477
case (gpu_core.UnswizzleRef(rhs_swizzle), gpu_core.UntileRef(rhs_tiling)):
466478
rhs_transpose = False
@@ -474,16 +486,18 @@ def _wgmma_lowering(
474486
case _:
475487
raise ValueError(f"WGMMA rhs has unsupported transforms: {b_transforms}.")
476488

477-
if rhs_swizzle != swizzle:
478-
raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle")
479-
if rhs_tiling != (swizzle_elems, swizzle_elems):
480-
raise NotImplementedError("WGMMA rhs tiling does not fit swizzle")
489+
if lhs_swizzle is not None:
490+
swizzle_elems = rhs_swizzle // a_aval.dtype.itemsize
491+
if rhs_swizzle != lhs_swizzle:
492+
raise NotImplementedError("WGMMA rhs swizzle must match lhs swizzle")
493+
if rhs_tiling != (swizzle_elems, swizzle_elems):
494+
raise NotImplementedError("WGMMA rhs tiling does not fit swizzle")
481495

482496
new_acc = mgpu.wgmma(
483497
acc,
484498
a,
485499
b,
486-
swizzle=swizzle,
500+
swizzle=rhs_swizzle,
487501
b_order=mgpu.WGMMALayout.COL_MAJOR
488502
if rhs_transpose
489503
else mgpu.WGMMALayout.ROW_MAJOR,
@@ -493,12 +507,12 @@ def _wgmma_lowering(
493507

494508

495509
@wgmma_p.def_effectful_abstract_eval
496-
def _wgmma_effectful_abstract_eval(acc, *args, **kwargs):
510+
def _wgmma_effectful_abstract_eval(acc, lhs_ref, *args, **kwargs):
497511
del args, kwargs
498512
return acc, {
499513
_wgmma_pipeline_effect,
500-
state.ReadEffect(1),
501514
state.ReadEffect(2),
515+
*([state.ReadEffect(1)] if isinstance(lhs_ref, state.AbstractRef) else [])
502516
}
503517

504518
wgmma_wait_p = jax_core.Primitive("wgmma_wait")

tests/pallas/mosaic_gpu_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,29 @@ def scope(acc_ref):
705705
res, a @ (b.T if rhs_transpose else b), rtol=1e-3
706706
)
707707

708+
def test_wgmma_registers(self):
709+
def kernel(a_ref, b_ref, o_ref):
710+
def scope(acc_ref):
711+
plgpu.wgmma(acc_ref, a_ref[...], b_ref)
712+
return acc_ref[...]
713+
o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 192), jnp.float32))
714+
715+
key1, key2 = jax.random.split(jax.random.key(42), 2)
716+
a = jax.random.uniform(key1, shape=(64, 128), dtype=jnp.float16)
717+
b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16)
718+
719+
transforms = (plgpu.TilingTransform((64, 64)), plgpu.SwizzleTransform(128))
720+
res = pl.pallas_call(
721+
kernel,
722+
in_specs=[
723+
plgpu.GPUBlockSpec((64, 128), lambda: (0, 0), transforms=transforms),
724+
plgpu.GPUBlockSpec((128, 192), lambda: (0, 0), transforms=transforms),
725+
],
726+
out_specs=plgpu.GPUBlockSpec((64, 192), lambda: (0, 0)),
727+
out_shape=jax.ShapeDtypeStruct((64, 192), jnp.float32),
728+
)(a, b)
729+
np.testing.assert_allclose(res, a @ b, rtol=1e-3)
730+
708731
def test_wgmma_sliced_ref(self):
709732
def kernel(a_ref, b_ref, o_ref):
710733
def scope(acc_ref):

0 commit comments

Comments
 (0)