
Commit ef7df1a

cperivol authored and Google-ML-Automation committed
[pallas_mgpu] Allow trees (e.g. tuples) to be returned from cond_p expressions.
PiperOrigin-RevId: 700136799
1 parent: ebea435

File tree (2 files changed: +51 -6)

  jax/_src/pallas/mosaic_gpu/lowering.py
  tests/pallas/mosaic_gpu_test.py
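
What the change enables at the user level: inside a Pallas Mosaic GPU kernel, the branches of jax.lax.cond may now return a pytree (a tuple, dict, or any nested structure) rather than a single flat value. The new test_cond_returning_array test below is the authoritative example from this commit; the sketch here is a hypothetical variant using a dict pytree, assuming an environment where pl.pallas_call lowers through the Mosaic GPU backend:

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Hypothetical sketch (not from this commit): cond branches returning a
# dict pytree. Both branches must return the same tree structure; the
# lowering flattens it to leaves and rebuilds it from the switch results.
@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
)
def kernel(x_ref, o_ref):
  acc = x_ref[...].sum()
  out = jax.lax.cond(
      acc % 2 == 0,
      lambda: {"hi": acc * 2, "lo": acc},
      lambda: {"hi": acc, "lo": acc * 2},
  )
  o_ref[...] = jnp.broadcast_to(out["hi"] + out["lo"], o_ref.shape)

x = jnp.arange(256, dtype=jnp.int32)
# Either branch yields hi + lo == 3 * acc.
print(kernel(x))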

jax/_src/pallas/mosaic_gpu/lowering.py (33 additions, 6 deletions)
@@ -1514,8 +1514,24 @@ def _while_lowering_rule(
 @register_lowering_rule(lax.cond_p)
 def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
   index_aval, *_arg_avals = ctx.avals_in
+
+  # We need the branch return mlir types in order to construct the
+  # switch operation. To avoid leaking information about what kind of
+  # mlir types are internal to FragmentedArrays and other mgpu types,
+  # we run one of the branches in a dummy module that we throw away to
+  # extract the return types
+  with ir.InsertionPoint(ir.Module.create().body):
+    outs = lower_jaxpr_to_mosaic_gpu(
+        ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args
+    )
+    yielded = [
+        _ensure_ir_value(out, aval.dtype) or out
+        for out, aval in zip(outs, ctx.avals_out)
+    ]
+    yielded_leaves, _ = jax.tree.flatten(yielded)
+
   switch_op = scf_dialect.IndexSwitchOp(
-      map(mgpu_utils.dtype_to_ir_type, ctx.avals_out),
+      [v.type for v in yielded_leaves],
       _as_index(_ensure_ir_value(index, index_aval.dtype)),
       ir.DenseI64ArrayAttr.get(range(len(branches) - 1)),
       num_caseRegions=len(branches) - 1,
@@ -1527,16 +1543,27 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
   regions = list(switch_op.regions)
   # Move the default region to the back.
   regions = regions[1:] + regions[:1]
+  treedef = None
   for branch, region in zip(branches, regions):
     with ir.InsertionPoint(region.blocks.append()):
       outs = lower_jaxpr_to_mosaic_gpu(
-          ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args
+          ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts
       )
-      scf_dialect.yield_([
-          _ensure_ir_value(out, aval.dtype)
+
+      yielded = [
+          _ensure_ir_value(out, aval.dtype) or out
           for out, aval in zip(outs, ctx.avals_out)
-      ])
-  return list(switch_op.results)
+      ]
+      yielded_leaves, yielded_treedef = jax.tree.flatten(yielded)
+      if treedef is None:
+        treedef = yielded_treedef
+      else:
+        assert treedef == yielded_treedef
+
+      scf_dialect.yield_(yielded_leaves)
+
+  assert treedef is not None
+  return treedef.unflatten(list(switch_op.results))


 @register_lowering_rule(lax.bitcast_convert_type_p)
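
The core of the new rule is a flatten/check/unflatten round-trip: each branch's outputs are flattened to leaves (scf.IndexSwitchOp only deals in flat MLIR results), every branch's treedef is asserted equal, and the switch results are unflattened back into the branches' tree structure. A minimal standalone sketch of that pattern, with plain Python values standing in for the MLIR values the real rule handles:

import jax

# Stand-ins for the per-branch outputs the lowering rule produces.
branch_outs = [(1, (2, 3)), (10, (20, 30))]

treedef = None
flat_branches = []
for outs in branch_outs:
  leaves, this_treedef = jax.tree.flatten(outs)
  if treedef is None:
    treedef = this_treedef
  else:
    # Mirrors the assert in _cond_lowering_rule: all branches must
    # yield the same pytree structure.
    assert treedef == this_treedef
  flat_branches.append(leaves)

# The switch op traffics only in flat leaves; the caller gets the tree back.
results = flat_branches[0]  # stand-in for list(switch_op.results)
print(treedef.unflatten(results))  # (1, (2, 3))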

tests/pallas/mosaic_gpu_test.py (18 additions, 0 deletions)
@@ -749,6 +749,24 @@ def kernel(x_ref, o_ref):

     self.assertIn("acc * 2:", output())

+  def test_cond_returning_array(self):
+    @functools.partial(
+        pl.pallas_call,
+        out_shape=jax.ShapeDtypeStruct([256], jnp.int32),
+    )
+    def kernel(x_ref, o_ref):
+      acc = x_ref[...].sum()
+      acc2, acc = jax.lax.cond(
+          acc % 2 == 0,
+          lambda: (acc * 2, acc),
+          lambda: (acc, acc * 2),
+      )
+      o_ref[...] = jnp.broadcast_to(acc + acc2, o_ref.shape)
+
+    x = jnp.arange(256)
+    np.testing.assert_array_equal(kernel(x), jnp.broadcast_to(jnp.sum(x) * 3, [256]))
+
+
   @parameterized.parameters(jnp.float16, jnp.float32)
   def test_wgmma(self, dtype):
     self.skip_unless_sm90a()
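
A note on the test's expected value: whichever branch runs, the returned pair sums to 3 * acc (the even branch yields (acc * 2, acc), the odd branch (acc, acc * 2)), so the output is a broadcast of 3 * jnp.sum(x) regardless of the parity of the sum. That makes the assertion independent of which scf.index_switch region actually executes.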
