Skip to content

Commit 5c097c8

Browse files
bartchr808Google-ML-Automation
authored andcommitted
#sdy Move Shardy mesh lift inlining pass after verification.
Before if something went wrong during JAX lowering, then instead of verification catching this, the pass would making the error message difficult to read and incorrectly pointing to the pass as the source of the error. For example ``` File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module pipeline.run(ctx.module.operation) MLIRError: Failure while executing pass pipeline: error: ... 'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2 ... see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64> ``` PiperOrigin-RevId: 713314555
1 parent 0389d61 commit 5c097c8

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

jax/_src/interpreters/mlir.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,10 +1205,6 @@ def lower_jaxpr_to_module(
12051205
arg_layouts=in_layouts,
12061206
result_layouts=out_layouts,
12071207
propagated_out_mem_kinds=propagated_out_mem_kinds)
1208-
if config.use_shardy_partitioner.value:
1209-
pipeline = passmanager.PassManager.parse(
1210-
'builtin.module(sdy-lift-inlined-meshes)')
1211-
pipeline.run(ctx.module.operation)
12121208

12131209
try:
12141210
if not ctx.module.operation.verify():
@@ -1227,6 +1223,12 @@ def emit_diagnostic_info(d):
12271223
raise ValueError("\n".join(msg_lines) + "\n" +
12281224
dump_module_message(ctx.module, "verification")) from e
12291225

1226+
if config.use_shardy_partitioner.value:
1227+
with ctx.context:
1228+
pipeline = passmanager.PassManager.parse(
1229+
'builtin.module(sdy-lift-inlined-meshes)')
1230+
pipeline.run(ctx.module.operation)
1231+
12301232
return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
12311233
ctx.shape_poly_state)
12321234

0 commit comments

Comments
 (0)