Commit 5c097c8
#sdy Move Shardy mesh lift inlining pass after verification.
Before if something went wrong during JAX lowering, then instead of verification catching this, the pass would making the error message difficult to read and incorrectly pointing to the pass as the source of the error. For example
```
File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module
pipeline.run(ctx.module.operation)
MLIRError: Failure while executing pass pipeline:
error:
...
'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2
...
see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64>
```
PiperOrigin-RevId: 7133145551 parent 0389d61 commit 5c097c8
1 file changed
+6
-4
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1205 | 1205 | | |
1206 | 1206 | | |
1207 | 1207 | | |
1208 | | - | |
1209 | | - | |
1210 | | - | |
1211 | | - | |
1212 | 1208 | | |
1213 | 1209 | | |
1214 | 1210 | | |
| |||
1227 | 1223 | | |
1228 | 1224 | | |
1229 | 1225 | | |
| 1226 | + | |
| 1227 | + | |
| 1228 | + | |
| 1229 | + | |
| 1230 | + | |
| 1231 | + | |
1230 | 1232 | | |
1231 | 1233 | | |
1232 | 1234 | | |
| |||
0 commit comments