Skip to content

Commit 9e1383f

Browse files
committed
[mlir] Fix #93973 - linalg::ReduceOp verifier crash
1 parent 98204a2 commit 9e1383f

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [
331331
def ReduceOp : LinalgStructuredBase_Op<"reduce", [
332332
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
333333
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
334-
SameVariadicOperandSize,
334+
AttrSizedOperandSegments,
335335
SingleBlockImplicitTerminator<"YieldOp">]> {
336336
let summary = "Reduce operator";
337337
let description = [{

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,11 +1339,12 @@ LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
13391339
static ParseResult parseDstStyleOp(
13401340
OpAsmParser &parser, OperationState &result,
13411341
function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1342-
nullptr) {
1342+
nullptr,
1343+
bool addOperandSegmentSizes = false) {
13431344
// Parse `ins` and `outs`.
13441345
SmallVector<Type, 4> inputTypes, outputTypes;
13451346
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1346-
/*addOperandSegmentSizes=*/false))
1347+
addOperandSegmentSizes))
13471348
return failure();
13481349

13491350
// Add result types.
@@ -1694,9 +1695,12 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
16941695
}
16951696

16961697
if (parseDstStyleOp(
1697-
parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1698+
parser, result,
1699+
[&](OpAsmParser &parser, NamedAttrList &attributes) {
16981700
return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1699-
}))
1701+
},
1702+
/*addOperandSegmentSizes=*/true))
1703+
17001704
return failure();
17011705

17021706
if (payloadOpName.has_value()) {
@@ -1731,7 +1735,9 @@ void ReduceOp::print(OpAsmPrinter &p) {
17311735

17321736
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
17331737
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1734-
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1738+
p.printOptionalAttrDict(
1739+
(*this)->getAttrs(),
1740+
{getDimensionsAttrName(), getOperandSegmentSizesAttrName()});
17351741
if (!payloadOp) {
17361742
// Print region if the payload op was not detected.
17371743
p.increaseIndent();

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,48 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
485485

486486
// -----
487487

488+
func.func @reduce_asymmetric(%input: tensor<16x32x64xi32>, %input2: tensor<16x32x64xi32>,
489+
%init: tensor<16x64xi32>) -> tensor<16x64xi32> {
490+
%reduce = linalg.reduce
491+
ins(%input, %input2:tensor<16x32x64xi32>, tensor<16x32x64xi32>)
492+
outs(%init:tensor<16x64xi32>)
493+
dimensions = [1]
494+
(%in: i32, %in2: i32, %out: i32) {
495+
%0 = arith.muli %in, %in2: i32
496+
%1 = arith.addi %out, %0: i32
497+
linalg.yield %1: i32
498+
}
499+
func.return %reduce : tensor<16x64xi32>
500+
}
501+
// CHECK-LABEL: func @reduce_asymmetric
502+
// CHECK: linalg.reduce ins(%{{.*}}, %{{.*}}: tensor<16x32x64xi32>, tensor<16x32x64xi32>)
503+
// CHECK-NOT: operandSegmentSize
504+
// CHECK-SAME: outs(%{{.*}}: tensor<16x64xi32>)
505+
// CHECK-SAME: dimensions = [1]
506+
507+
// -----
508+
509+
func.func @reduce_asymmetric_memref(%input: memref<16x32x64xi32>, %input2: memref<16x32x64xi32>,
510+
%init: memref<16x64xi32>) {
511+
linalg.reduce
512+
ins(%input, %input2:memref<16x32x64xi32>, memref<16x32x64xi32>)
513+
outs(%init:memref<16x64xi32>)
514+
dimensions = [1]
515+
(%in: i32, %in2: i32, %out: i32) {
516+
%0 = arith.muli %in, %in2: i32
517+
%1 = arith.addi %out, %0: i32
518+
linalg.yield %1: i32
519+
}
520+
func.return
521+
}
522+
// CHECK-LABEL: func @reduce_asymmetric_memref
523+
// CHECK: linalg.reduce ins(%{{.*}}, %{{.*}}: memref<16x32x64xi32>, memref<16x32x64xi32>)
524+
// CHECK-NOT: operandSegmentSize
525+
// CHECK-SAME: outs(%{{.*}}: memref<16x64xi32>)
526+
// CHECK-SAME: dimensions = [1]
527+
528+
// -----
529+
488530
func.func @transpose(%input: tensor<16x32x64xf32>,
489531
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
490532
%transpose = linalg.transpose

0 commit comments

Comments
 (0)