Skip to content

Commit d531704

Browse files
authored
Add verifier for migraphx.reshape (#1999)
* Add verifier for MIGraphX Reshape op * Update for values of -1 * Clang-format * Attend to review comments * Make more changes based on review comments * Fix outDim handling
1 parent 8cce1d3 commit d531704

File tree

4 files changed

+109
-3
lines changed

4 files changed

+109
-3
lines changed

mlir/include/mlir/Dialect/MIGraphX/IR/MIGraphX.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,18 +439,25 @@ def MIGraphX_TransposeOp :
439439
}
440440

441441
def MIGraphX_ReshapeOp :
442-
MIGraphX_Op<"reshape">,
442+
MIGraphX_Op<"reshape", [AllElementTypesMatch<["input", "output"]>]>,
443443
Arguments<(ins AnyMIXRShaped:$input,
444444
I64ArrayAttr:$dims
445445
)>,
446446
Results<(outs AnyMIXRShaped:$output)> {
447447
let summary = "reshape a tensor";
448448
let description = [{
449449
The `migraphx.reshape` op.
450+
451+
`dims` can contain values of -1 and 0. A value of -1 means to infer this
452+
dimension from the others. i.e., If we have an output shape of
453+
`<4x2xf32, 0x1>` and the dims are `[4, -1]`, then the second dimension will
454+
be inferred to be 2. A value of 0 means explicit zero-length dimension.
450455
}];
451456
let assemblyFormat = [{
452457
$input attr-dict `:` type($input) `->` type($output)
453458
}];
459+
460+
let hasVerifier = 1;
454461
}
455462

456463
def MIGraphX_SliceOp :

mlir/lib/Dialect/MIGraphX/IR/MIGraphX.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,61 @@ LogicalResult LiteralOp::verify() {
313313
return success();
314314
}
315315

316+
LogicalResult ReshapeOp::verify() {
317+
MIXRShapedType inputType = getInput().getType();
318+
MIXRShapedType outType = getOutput().getType();
319+
ArrayAttr dimsAttr = getDims();
320+
321+
// Dynamic shapes are not currently supported
322+
if (!inputType.hasStaticShape())
323+
return emitOpError("Dynamic shapes are not supported");
324+
325+
if (dimsAttr.size() != outType.getRank())
326+
return emitOpError("number of dims (")
327+
<< dimsAttr.size() << ") does not match result rank ("
328+
<< outType.getRank() << ")";
329+
330+
// Check that there is only a single -1 value
331+
int missingDims = llvm::count_if(
332+
dimsAttr.getAsRange<IntegerAttr>(),
333+
[](IntegerAttr a) { return a.getInt() == -1; });
334+
if (missingDims > 1)
335+
return emitOpError("expected at most one target dimension to be -1");
336+
337+
// Check how many zero dimensions there are
338+
int numZeros = llvm::count_if(
339+
dimsAttr.getAsRange<IntegerAttr>(),
340+
[](IntegerAttr a) { return a.getInt() == 0; });
341+
342+
if (missingDims > 0 && numZeros > 0)
343+
return emitOpError("Cannot mix missing dimensions with zero dimension");
344+
345+
// Compare dimension values to output shape
346+
for (auto [dimVal, outDim] : llvm::zip(dimsAttr, outType.getShape())) {
347+
int64_t dimValue = cast<IntegerAttr>(dimVal).getInt();
348+
// We cannot handle negative dims values that aren't -1
349+
if (dimValue < -1 ) {
350+
return emitOpError("Non -1 negative values are not supported");
351+
}
352+
353+
// Output dimensions can't be negative
354+
if (outDim < 0)
355+
return emitOpError("Negative output dimensions are not supported");
356+
357+
// Per-dimension consistency
358+
if (dimValue >= 0 && outDim != dimValue)
359+
return emitOpError("dimValue: ")
360+
<< dimValue << " inconsistent with result dimension " << outDim;
361+
}
362+
363+
// Check that the number of elements in the input and output types match
364+
int64_t inputElements = inputType.getNumElements();
365+
if (inputElements != outType.getNumElements())
366+
return emitOpError("input and output element counts do not match");
367+
368+
return success();
369+
}
370+
316371
LogicalResult UnpackOp::verify() {
317372
MIXRShapedType inType = getIn().getType();
318373
MIXRShapedType outType = getOut().getType();
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: rocmlir-opt %s -split-input-file -verify-diagnostics
2+
3+
func.func @mlir_reshape_inconsistent_dims(%arg0: !migraphx.shaped<4096x4096xf16, 0x1>) {
4+
// expected-error@+1 {{'migraphx.reshape' op dimValue: 64 inconsistent with result dimension 4096}}
5+
%0 = migraphx.reshape %arg0 {dims = [64, 128]} : <4096x4096xf16, 0x1> -> <4096x4096xf16, 16536x2>
6+
return
7+
}
8+
9+
func.func @mlir_reshape_dynamic_shape(%arg0: !migraphx.shaped<4096x?xf16, 0x1>) {
10+
// expected-error@+1 {{'migraphx.reshape' op Dynamic shapes are not supported}}
11+
%0 = migraphx.reshape %arg0 {dims = [4096, 4096]} : <4096x?xf16, 0x1> -> <4096x?xf16, 16536x2>
12+
return
13+
}
14+
15+
func.func @mlir_reshape_rank(%arg0: !migraphx.shaped<4096x4096xf16, 0x1>) {
16+
// expected-error@+1 {{'migraphx.reshape' op number of dims (3) does not match result rank (2)}}
17+
%0 = migraphx.reshape %arg0 {dims = [1, 4096, 4096]} : <4096x4096xf16, 0x1> -> <4096x4096xf16, 16536x2>
18+
return
19+
}
20+
21+
func.func @mlir_num_input_elements(%arg0: !migraphx.shaped<2x4xf16, 0x1>) {
22+
// expected-error@+1 {{'migraphx.reshape' op input and output element counts do not match}}
23+
%0 = migraphx.reshape %arg0 {dims = [3, 5]} : <2x4xf16, 0x1> -> <3x5xf16, 0x1>
24+
return
25+
}
26+
27+
func.func @mlir_element_type(%arg0: !migraphx.shaped<2x4xf16, 0x1>) {
28+
// expected-error@+1 {{'migraphx.reshape' op failed to verify that all of {input, output} have same element type}}
29+
%0 = migraphx.reshape %arg0 {dims = [4, 2]} : <2x4xf16, 0x1> -> <4x2xf32, 0x1>
30+
return
31+
}
32+
33+
func.func @mlir_multiple_neg_one(%arg0: !migraphx.shaped<2x4xf16, 0x1>) {
34+
// expected-error@+1 {{'migraphx.reshape' op expected at most one target dimension to be -1}}
35+
%0 = migraphx.reshape %arg0 {dims = [-1, -1]} : <2x4xf16, 0x1> -> <4x2xf16, 0x1>
36+
return
37+
}
38+
39+
func.func @mlir_neg_one_with_zero(%arg0: !migraphx.shaped<2x4xf16, 0x1>) {
40+
// expected-error@+1 {{'migraphx.reshape' op Cannot mix missing dimensions with zero dimension}}
41+
%0 = migraphx.reshape %arg0 {dims = [0, -1]} : <2x4xf16, 0x1> -> <4x2xf16, 0x1>
42+
return
43+
}
44+

mlir/test/fusion/mixr-attention-realtype-int4-dequantizelinear.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ module {
1515
%0 = migraphx.unpack %arg3 {axis = 1 : i64} : <4096x2048xui8, 2048x1> -> <4096x4096xi8, 4096x1>
1616
%1 = migraphx.broadcast %arg1 {axis = 0 : i64, out_lens = [4096, 4096]} : <4096xf16, 1> -> <4096x4096xf16, 0x1>
1717
%2 = migraphx.broadcast %arg2 {axis = 0 : i64, out_lens = [4096, 4096]} : <4096xf16, 1> -> <4096x4096xf16, 0x1>
18-
%3 = migraphx.reshape %1 {dims = [64, 128]} : <4096x4096xf16, 0x1> -> <4096x4096xf16, 16536x2>
19-
%4 = migraphx.reshape %2 {dims = [64, 128]} : <4096x4096xf16, 0x1> -> <4096x4096xf16, 16536x2>
18+
%3 = migraphx.reshape %1 {dims = [4096, 4096]} : <4096x4096xf16, 0x1> -> <4096x4096xf16, 16536x2>
19+
%4 = migraphx.reshape %2 {dims = [4096, 4096]} : <4096x4096xf16, 0x1> -> <4096x4096xf16, 16536x2>
2020
%5 = migraphx.dequantizelinear %0, %3, %4 : <4096x4096xi8, 4096x1>, <4096x4096xf16, 16536x2>, !migraphx.shaped<4096x4096xf16, 16536x2> -> <4096x4096xf16, 4096x1>
2121
%6 = migraphx.dot %2, %5 : <4096x4096xf16, 0x1>, <4096x4096xf16, 4096x1> -> <4096x4096xf16, 4096x1>
2222
%7 = migraphx.softmax %6 {axis = 1 : i64} : <4096x4096xf16, 4096x1> -> <4096x4096xf16, 4096x1>

0 commit comments

Comments
 (0)