Skip to content

Commit 1ddba9b

Browse files
cperivolGoogle-ML-Automation
authored andcommitted
[mgpu_pallas] Optionally pass default value instead of raising an error when trying to ensure ir Value.
PiperOrigin-RevId: 702672662
1 parent cb2cf56 commit 1ddba9b

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources:
142142
# Assume that unsupported primitives are neutral wrt resource usage.
143143
continue
144144
rs |= rule(*(invar.aval for invar in eqn.invars), **eqn.params)
145+
145146
return rs
146147

147148

@@ -1592,6 +1593,15 @@ def _while_lowering_rule(
15921593
def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
15931594
index_aval, *_arg_avals = ctx.avals_in
15941595

1596+
def _yielded_values(outs, avals):
1597+
ret = []
1598+
for out, aval in zip(outs, avals):
1599+
if isinstance(out, mgpu.FragmentedArray):
1600+
ret.append(out)
1601+
else:
1602+
ret.append(_ensure_ir_value(out, aval.dtype))
1603+
return ret
1604+
15951605
# We need the branch return mlir types in order to construct the
15961606
# switch operation. To avoid leaking information about what kind of
15971607
# mlir types are internal to FragmentedArrays and other mgpu types,
@@ -1601,10 +1611,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
16011611
outs = lower_jaxpr_to_mosaic_gpu(
16021612
ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args
16031613
)
1604-
yielded_types, _ = jax.tree.flatten([
1605-
(_ensure_ir_value(out, aval.dtype) or out).type
1606-
for out, aval in zip(outs, ctx.avals_out)
1607-
])
1614+
yielded_types = [v.type for v in jax.tree.leaves(_yielded_values(outs, ctx.avals_out))]
16081615

16091616
switch_op = scf_dialect.IndexSwitchOp(
16101617
yielded_types,
@@ -1626,11 +1633,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
16261633
ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args, consts=branch.consts
16271634
)
16281635

1629-
yielded = [
1630-
_ensure_ir_value(out, aval.dtype) or out
1631-
for out, aval in zip(outs, ctx.avals_out)
1632-
]
1633-
yielded_leaves, yielded_treedef = jax.tree.flatten(yielded)
1636+
yielded_leaves, yielded_treedef = jax.tree.flatten(_yielded_values(outs, ctx.avals_out))
16341637
if treedef is None:
16351638
treedef = yielded_treedef
16361639
else:

0 commit comments

Comments
 (0)