@@ -1514,8 +1514,24 @@ def _while_lowering_rule(
15141514@register_lowering_rule (lax .cond_p )
15151515def _cond_lowering_rule (ctx : LoweringRuleContext , index , * args , branches ):
15161516 index_aval , * _arg_avals = ctx .avals_in
1517+
1518+ # We need the branch return mlir types in order to construct the
1519+ # switch operation. To avoid leaking information about what kind of
1520+ # mlir types are internal to FragmentedArrays and other mgpu types,
1521+ # we run one of the branches in a dummy module that we throw away to
1522+ # extract the return types
1523+ with ir .InsertionPoint (ir .Module .create ().body ):
1524+ outs = lower_jaxpr_to_mosaic_gpu (
1525+ ctx .module_ctx , ctx .launch_ctx , branches [0 ].jaxpr , args
1526+ )
1527+ yielded = [
1528+ _ensure_ir_value (out , aval .dtype ) or out
1529+ for out , aval in zip (outs , ctx .avals_out )
1530+ ]
1531+ yielded_leaves , _ = jax .tree .flatten (yielded )
1532+
15171533 switch_op = scf_dialect .IndexSwitchOp (
1518- map ( mgpu_utils . dtype_to_ir_type , ctx . avals_out ) ,
1534+ [ v . type for v in yielded_leaves ] ,
15191535 _as_index (_ensure_ir_value (index , index_aval .dtype )),
15201536 ir .DenseI64ArrayAttr .get (range (len (branches ) - 1 )),
15211537 num_caseRegions = len (branches ) - 1 ,
@@ -1527,16 +1543,27 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
15271543 regions = list (switch_op .regions )
15281544 # Move the default region to the back.
15291545 regions = regions [1 :] + regions [:1 ]
1546+ treedef = None
15301547 for branch , region in zip (branches , regions ):
15311548 with ir .InsertionPoint (region .blocks .append ()):
15321549 outs = lower_jaxpr_to_mosaic_gpu (
1533- ctx .module_ctx , ctx .launch_ctx , branch .jaxpr , args
1550+ ctx .module_ctx , ctx .launch_ctx , branch .jaxpr , args , consts = branch . consts
15341551 )
1535- scf_dialect .yield_ ([
1536- _ensure_ir_value (out , aval .dtype )
1552+
1553+ yielded = [
1554+ _ensure_ir_value (out , aval .dtype ) or out
15371555 for out , aval in zip (outs , ctx .avals_out )
1538- ])
1539- return list (switch_op .results )
1556+ ]
1557+ yielded_leaves , yielded_treedef = jax .tree .flatten (yielded )
1558+ if treedef is None :
1559+ treedef = yielded_treedef
1560+ else :
1561+ assert treedef == yielded_treedef
1562+
1563+ scf_dialect .yield_ (yielded_leaves )
1564+
1565+ assert treedef is not None
1566+ return treedef .unflatten (list (switch_op .results ))
15401567
15411568
15421569@register_lowering_rule (lax .bitcast_convert_type_p )
0 commit comments