@@ -142,6 +142,7 @@ def _estimate_resources(jaxpr: jax_core.Jaxpr) -> Resources:
142142 # Assume that unsupported primitives are neutral wrt resource usage.
143143 continue
144144 rs |= rule (* (invar .aval for invar in eqn .invars ), ** eqn .params )
145+
145146 return rs
146147
147148
@@ -1592,6 +1593,15 @@ def _while_lowering_rule(
15921593def _cond_lowering_rule (ctx : LoweringRuleContext , index , * args , branches ):
15931594 index_aval , * _arg_avals = ctx .avals_in
15941595
1596+ def _yielded_values (outs , avals ):
1597+ ret = []
1598+ for out , aval in zip (outs , avals ):
1599+ if isinstance (out , mgpu .FragmentedArray ):
1600+ ret .append (out )
1601+ else :
1602+ ret .append (_ensure_ir_value (out , aval .dtype ))
1603+ return ret
1604+
15951605 # We need the branch return mlir types in order to construct the
15961606 # switch operation. To avoid leaking information about what kind of
15971607 # mlir types are internal to FragmentedArrays and other mgpu types,
@@ -1601,10 +1611,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
16011611 outs = lower_jaxpr_to_mosaic_gpu (
16021612 ctx .module_ctx , ctx .launch_ctx , branches [0 ].jaxpr , args
16031613 )
1604- yielded_types , _ = jax .tree .flatten ([
1605- (_ensure_ir_value (out , aval .dtype ) or out ).type
1606- for out , aval in zip (outs , ctx .avals_out )
1607- ])
1614+ yielded_types = [v .type for v in jax .tree .leaves (_yielded_values (outs , ctx .avals_out ))]
16081615
16091616 switch_op = scf_dialect .IndexSwitchOp (
16101617 yielded_types ,
@@ -1626,11 +1633,7 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
16261633 ctx .module_ctx , ctx .launch_ctx , branch .jaxpr , args , consts = branch .consts
16271634 )
16281635
1629- yielded = [
1630- _ensure_ir_value (out , aval .dtype ) or out
1631- for out , aval in zip (outs , ctx .avals_out )
1632- ]
1633- yielded_leaves , yielded_treedef = jax .tree .flatten (yielded )
1636+ yielded_leaves , yielded_treedef = jax .tree .flatten (_yielded_values (outs , ctx .avals_out ))
16341637 if treedef is None :
16351638 treedef = yielded_treedef
16361639 else :
0 commit comments