-
I used the following command to dump the LMHLO modules from the resnet50.py example with jaxlib-0.1.66.
In the generated module_0269.lmhlo, line 5179 contains this op:
`%947 = memref.view %arg417[%c0_489][] : memref<1x1x512x2048xf32> to memref<512x2048x1x1xf32>`
When I ran an optimization on that file, it failed with this error:
`error: 'memref.view' op operand #0 must be 1D memref of 8-bit signless integer values, but got 'memref<1x1x512x2048xf32>'`
module_0269.lmhlo contains more cases like this, where the op's first operand is multidimensional.
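The error comes from the operand constraint on `memref.view`: its source must be a 1-D memref of `i8` (a raw byte buffer), and the result is carved out of that buffer at a byte offset (the `[%c0_489]` part). The `f32` source in the dump breaks that rule. As a rough illustration only, here is a small Python sketch that parses simple memref type strings and checks the operand rule, and also computes how many bytes a valid `i8` source would need to back the `512x2048x1x1xf32` result. The helper names (`parse_memref`, `is_valid_view_source`, `bytes_needed`) are hypothetical and not part of MLIR or XLA; the parser assumes static shapes and the default (identity) layout.

```python
import math
import re

# Hypothetical parser for static, identity-layout memref types such as
# "memref<1x1x512x2048xf32>"; dynamic dims and layouts are not supported.
_MEMREF_RE = re.compile(r"^memref<((?:\d+x)*)(\w+)>$")

# Element sizes in bytes for the few element types used here (assumption).
_ELEM_BYTES = {"i8": 1, "f16": 2, "f32": 4, "f64": 8}


def parse_memref(type_str):
    """Return (shape, element_type) for a simple memref type string."""
    m = _MEMREF_RE.match(type_str.strip())
    if m is None:
        raise ValueError(f"unsupported memref type: {type_str}")
    dims = [int(d) for d in m.group(1).split("x") if d]
    return dims, m.group(2)


def is_valid_view_source(type_str):
    """memref.view requires its source operand to be a 1-D memref of i8."""
    shape, elem = parse_memref(type_str)
    return len(shape) == 1 and elem == "i8"


def bytes_needed(type_str):
    """Bytes a view result of this type occupies in the underlying i8 buffer."""
    shape, elem = parse_memref(type_str)
    return math.prod(shape) * _ELEM_BYTES[elem]


print(is_valid_view_source("memref<1x1x512x2048xf32>"))  # False
print(is_valid_view_source("memref<4194304xi8>"))        # True
print(bytes_needed("memref<512x2048x1x1xf32>"))          # 4194304
```

Both shapes hold 512 × 2048 elements and differ only in where the unit dimensions sit, so with the default row-major layout the element order is unchanged; the verifier rejects the op only because a `memref.view` source must be a flat byte buffer.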
-
I'm not sure, but this looks like a version mismatch between your version of MLIR (the one running the "optimization") and the version used internally by XLA. Note that the MLIR outputs are debug dumps, and JAX itself didn't misbehave, so this isn't a bug. I'm going to convert this issue into a discussion.
-
I think this is a problem with XLA's internal MLIR representation. Maybe @timshen91 can take a look at it.
-
I have a fix. I'm about to submit it.