Skip to content

Commit b1423a3

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Pallas:MGPU] Fix a use-after-free in lowering
The lifetime of values is bound to the ops that produce them, which are deleted after the `with` block. The lifetime of types is bound to the context. PiperOrigin-RevId: 701997797
1 parent 8a31619 commit b1423a3

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,14 +1538,13 @@ def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches):
15381538
outs = lower_jaxpr_to_mosaic_gpu(
15391539
ctx.module_ctx, ctx.launch_ctx, branches[0].jaxpr, args
15401540
)
1541-
yielded = [
1542-
_ensure_ir_value(out, aval.dtype) or out
1541+
yielded_types, _ = jax.tree.flatten([
1542+
(_ensure_ir_value(out, aval.dtype) or out).type
15431543
for out, aval in zip(outs, ctx.avals_out)
1544-
]
1545-
yielded_leaves, _ = jax.tree.flatten(yielded)
1544+
])
15461545

15471546
switch_op = scf_dialect.IndexSwitchOp(
1548-
[v.type for v in yielded_leaves],
1547+
yielded_types,
15491548
_as_index(_ensure_ir_value(index, index_aval.dtype)),
15501549
ir.DenseI64ArrayAttr.get(range(len(branches) - 1)),
15511550
num_caseRegions=len(branches) - 1,

0 commit comments

Comments
 (0)