diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 37eec6e07963b..461e2ed091fa4 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -331,7 +331,7 @@ def MapOp : LinalgStructuredBase_Op<"map", [ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - SameVariadicOperandSize, + AttrSizedOperandSegments, SingleBlockImplicitTerminator<"YieldOp">]> { let summary = "Reduce operator"; let description = [{ diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index d9840e3923c4f..2fa1405ff8618 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1339,11 +1339,12 @@ LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl &) { static ParseResult parseDstStyleOp( OpAsmParser &parser, OperationState &result, function_ref parseAttrsFn = - nullptr) { + nullptr, + bool addOperandSegmentSizes = false) { // Parse `ins` and `outs`. SmallVector inputTypes, outputTypes; if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes, - /*addOperandSegmentSizes=*/false)) + addOperandSegmentSizes)) return failure(); // Add result types. @@ -1694,9 +1695,12 @@ ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { } if (parseDstStyleOp( - parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { + parser, result, + [&](OpAsmParser &parser, NamedAttrList &attributes) { return parseDenseI64ArrayAttr(parser, attributes, "dimensions"); - })) + }, + /*addOperandSegmentSizes=*/true)) + return failure(); if (payloadOpName.has_value()) { @@ -1731,7 +1735,9 @@ void ReduceOp::print(OpAsmPrinter &p) { printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); - p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); + p.printOptionalAttrDict( + (*this)->getAttrs(), + {getDimensionsAttrName(), getOperandSegmentSizesAttrName()}); if (!payloadOp) { // Print region if the payload op was not detected. p.increaseIndent(); diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 1b8969bd11559..d4ad7584d00d8 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -485,6 +485,48 @@ func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>, // ----- +func.func @reduce_asymmetric(%input: tensor<16x32x64xi32>, %input2: tensor<16x32x64xi32>, + %init: tensor<16x64xi32>) -> tensor<16x64xi32> { + %reduce = linalg.reduce + ins(%input, %input2:tensor<16x32x64xi32>, tensor<16x32x64xi32>) + outs(%init:tensor<16x64xi32>) + dimensions = [1] + (%in: i32, %in2: i32, %out: i32) { + %0 = arith.muli %in, %in2: i32 + %1 = arith.addi %out, %0: i32 + linalg.yield %1: i32 + } + func.return %reduce : tensor<16x64xi32> +} +// CHECK-LABEL: func @reduce_asymmetric +// CHECK: linalg.reduce ins(%{{.*}}, %{{.*}}: tensor<16x32x64xi32>, tensor<16x32x64xi32>) +// CHECK-NOT: operandSegmentSize +// CHECK-SAME: outs(%{{.*}}: tensor<16x64xi32>) +// CHECK-SAME: dimensions = [1] + +// ----- + +func.func @reduce_asymmetric_memref(%input: memref<16x32x64xi32>, %input2: memref<16x32x64xi32>, + %init: memref<16x64xi32>) { + linalg.reduce + ins(%input, %input2:memref<16x32x64xi32>, memref<16x32x64xi32>) + outs(%init:memref<16x64xi32>) + dimensions = [1] + (%in: i32, %in2: i32, %out: i32) { + %0 = arith.muli %in, %in2: i32 + %1 = arith.addi %out, %0: i32 + linalg.yield %1: i32 + } + func.return +} +// CHECK-LABEL: func @reduce_asymmetric_memref +// CHECK: linalg.reduce ins(%{{.*}}, %{{.*}}: memref<16x32x64xi32>, memref<16x32x64xi32>) +// CHECK-NOT: operandSegmentSize +// CHECK-SAME: outs(%{{.*}}: memref<16x64xi32>) +// CHECK-SAME: dimensions = [1] + +// ----- + func.func @transpose(%input: tensor<16x32x64xf32>, %init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> { %transpose = linalg.transpose