
Commit c4cc94a

[Mosaic GPU] Add warpgroup lowering for RunState in Pallas.
After this change we no longer skip tests that required `RunState`. This necessitated a small fix in the Pallas lowering of `while` and also enabling multiple i32 register bundling in the `optimization_barrier` lowering.

PiperOrigin-RevId: 745065173
1 parent d12cbff commit c4cc94a
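
For context, the primitive in question backs `pl.run_state`: the kernel hands a stateful scope a WGMMA accumulator reference, and the accumulator's final contents are what the scope leaves behind. The sketch below is a loose, condensed paraphrase of the `test_wgmma_registers_init` kernel from the test diff further down; the `plgpu.ACC.init` initializer and the omission of block specs/transforms are illustrative assumptions, not code taken from this commit.

# Sketch only: the run_state + WGMMA accumulator pattern that this commit
# teaches the warpgroup lowering to handle. plgpu.ACC.init and the absence
# of block specs/transforms are assumptions; see the test diff below for
# the real setup.
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def kernel(a_ref, b_ref, i_ref, o_ref):
  def scope(acc_ref):
    # acc_ref is the WGMMA accumulator reference managed by run_state.
    plgpu.wgmma(acc_ref, a_ref[...], b_ref)

  # run_state (discharge.run_state_p) seeds the accumulator from the values
  # in i_ref and returns its final contents once the scope is done.
  o_ref[...] = pl.run_state(scope)(plgpu.ACC.init(i_ref[...]))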

2 files changed (+31, -22 lines)

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 19 additions & 12 deletions
@@ -2034,6 +2034,7 @@ def _run_scoped_lowering_rule(
 
 
 @register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Lane)
+@register_lowering_rule(discharge.run_state_p, mgpu.LoweringSemantics.Warpgroup)
 def _run_state_lowering_rule(
     ctx: LoweringRuleContext,
     *args,
@@ -2051,7 +2052,12 @@ def _run_state_lowering_rule(
   for arg, v, out_aval in zip(args, jaxpr.invars, ctx.avals_out):
     aval = v.aval
     if isinstance(aval, gpu_core.WGMMAAbstractAccumulatorRef):
-      new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg))
+      if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Warpgroup:
+        arg = mgpu.dialect.optimization_barrier([arg])
+        nvvm_dialect.wgmma_fence_aligned()
+        new_input_vals.append(arg)
+      else:
+        new_input_vals.append(mgpu.WGMMAAccumulator.from_registers(arg))
       should_discharge.append(True)
       assert isinstance(out_aval, jax_core.ShapedArray)
     else:
@@ -2273,18 +2279,19 @@ def _while_lowering_rule(
         ctx.module_ctx, ctx.launch_ctx, body_jaxpr.jaxpr, body_args
     )
     loop_out = [*map(_ensure, loop_out, carry_avals)]
-    for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)):
-      if _is_acc(carry_fa) != _is_acc(out_fa):
-        raise ValueError(
-            f"The loop body output has unexpected accumulator type: output[{idx}]"
-            f" is {out_fa}, when it should be {carry_fa}."
-        )
+    if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
+      for idx, (carry_fa, out_fa) in enumerate(zip(carry, loop_out)):
+        if _is_acc(carry_fa) != _is_acc(out_fa):
+          raise ValueError(
+              f"The loop body output has unexpected accumulator type:"
+              f" output[{idx}] is {out_fa}, when it should be {carry_fa}."
+          )
 
-      if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout:
-        raise ValueError(
-            f"The loop body output has unexpected layout: output[{idx}] has"
-            f" layout {out_fa.layout}, when it should be {carry_fa.layout}."
-        )
+        if not _is_acc(out_fa) and carry_fa.layout != out_fa.layout:
+          raise ValueError(
+              f"The loop body output has unexpected layout: output[{idx}] has"
+              f" layout {out_fa.layout}, when it should be {carry_fa.layout}."
+          )
     scf_dialect.yield_(
         carry_treedef.flatten_up_to(loop_out) if loop_out else []
     )
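
Two things change in this file. First, under warpgroup semantics the accumulator operand is routed through `optimization_barrier` followed by `wgmma_fence_aligned`, presumably so the registers seeding the accumulator are materialized and ordered before any subsequent WGMMA, instead of being wrapped in a `WGMMAAccumulator` as in the lane path. Second, the accumulator-type and layout checks in the `while` lowering now run only under lane semantics, where the carries are fragmented arrays that actually expose a `.layout`. Below is a hedged sketch of the kind of loop these checks guard, loosely following `test_fori_loop_accumulator` from the test diff further down; shapes, dtypes, and the `plgpu.ACC.init` initializer are assumptions.

# Sketch only: a fori_loop whose carry is the accumulator's current value,
# run inside pl.run_state. Shapes/dtypes and ACC.init are illustrative.
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu


def kernel(i_ref, o_ref):
  def scope(acc_ref):
    # The loop carry starts as the accumulator's value; under lane semantics
    # the carry is an accumulator fragmented array, and the (now lane-only)
    # checks verify the body returns a carry of matching kind and layout.
    o_ref[...] = jax.lax.fori_loop(
        0, 4, lambda _, v: v + acc_ref[...], acc_ref[...]
    )

  pl.run_state(scope)(plgpu.ACC.init(i_ref[...]))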

tests/pallas/mosaic_gpu_test.py

Lines changed: 12 additions & 10 deletions
@@ -32,7 +32,6 @@
 from jax._src.pallas.mosaic_gpu import lowering as mgpu_lowering
 from jax._src.pallas.mosaic_gpu import pipeline as mgpu_pipeline
 from jax._src.pallas.mosaic_gpu import primitives as mgpu_primitives
-from jax._src.state import discharge
 from jax.experimental import pallas as pl
 import jax.experimental.mosaic.gpu as mgpu
 from jax.experimental.pallas import mosaic_gpu as plgpu
@@ -1528,7 +1527,6 @@ def test_missing_primitive_lowerings_are_tracked(self):
         mgpu_primitives.layout_cast_p,
         mgpu_primitives.load_p,
         lax.slice_p,
-        discharge.run_state_p,
     }
 
     self.assertSetEqual(actual_missing_primitives, expected_missing_primitives)
@@ -1538,10 +1536,14 @@ class PallasCallSm90ATest(PallasSm90ATest):
 
   @parameterized.parameters(False, True)
   def test_fori_loop_accumulator(self, force_while):
-    # ``pl.run_state`` is not supported in WG semantics.
-    self.skip_if_wg_semantics()
-
-    transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
+    if force_while:
+      # Layout inference and lowering for 'while' are not yet implemented for
+      # warpgroup semantics.
+      self.skip_if_wg_semantics()
+    if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane:
+      transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
+    else:
+      transforms = ()
     @functools.partial(
         self.pallas_call,
         in_specs=[plgpu.GPUBlockSpec((64, 64), transforms=transforms)],
@@ -1733,9 +1735,6 @@ def scope(acc_ref):
     np.testing.assert_allclose(res, a @ b, rtol=1e-3)
 
   def test_wgmma_registers_init(self):
-    # ``pl.run_state`` is not supported in WG semantics.
-    self.skip_if_wg_semantics()
-
     def kernel(a_ref, b_ref, i_ref, o_ref):
       def scope(acc_ref):
         plgpu.wgmma(acc_ref, a_ref[...], b_ref)
@@ -1746,7 +1745,10 @@ def scope(acc_ref):
     b = jax.random.uniform(key2, shape=(128, 192), dtype=jnp.float16)
     i = jax.random.uniform(key3, shape=(64, 192), dtype=jnp.float16) * 10
 
-    transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
+    if self.LOWERING_SEMANTICS == plgpu.LoweringSemantics.Lane:
+      transforms = (plgpu.TilingTransform((8, 64)), plgpu.SwizzleTransform(128))
+    else:
+      transforms = ()
     res = self.pallas_call(
         kernel,
         in_specs=[
