diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index e32d3d01bb182..f3674c3eecfe6 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -253,10 +253,12 @@ def MapOp : LinalgStructuredBase_Op<"map", [ } ``` - Shortened print form is available. Applies to simple maps with one - non-yield operation inside the body. + Shortened print form is available for simple maps where the body contains exactly + two operations (the payload operation and a yield), the payload operation has + the same number of operands as block arguments with operands matching block + arguments in order, and the yield operand is the result of the payload operation. - The example above will be printed as: + The example above will be printed using the shortened form as: ```mlir %add = linalg.map { arith.addf } ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>) @@ -340,13 +342,15 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [ } ``` - Shortened print form is available. Applies to simple (not variadic) reduces - with one non-yield operation inside the body. Applies only if the operation - takes `%out` as the first argument. + Shortened print form is available for simple reduces where the body contains exactly + two operations (the payload operation and a yield), the payload operation has the + same number of operands as block arguments, the first block argument (init) is the + last operand of the payload operation with remaining operands matching remaining + block arguments in order, and the yield operand is the result of the payload operation. 
- The example above will be printed as: + The example above will be printed using the shortened form as: ```mlir - %reduce = linalg.reduce { arith.addf } + %reduce = linalg.reduce { arith.addf } ins(%input:tensor<16x32x64xf32>) outs(%init:tensor<16x64xf32>) dimensions = [1] diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 9d7fb18f56fef..7af4ea6a2f3a4 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1570,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -// Retrieve the operation from the body, if it is the only one (except -// yield) and if it gets the same amount of arguments as the body does. -// If initFirst flag is enabled, we check that init takes the first position in -// operands of payload. -static Operation *findPayloadOp(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false) { + // Check if the body can be printed in short form. The following 4 conditions + // must be satisfied: + + // 1) The body must contain exactly 2 operations: the payload op and a yield. if (body->getOperations().size() != 2) - return nullptr; + return false; Operation &payload = body->getOperations().front(); - assert(isa(body->getOperations().back())); + // 2) The payload op must have the same number of operands as the number of + // block arguments. if (payload.getNumOperands() == 0 || payload.getNumOperands() != body->getNumArguments()) - return nullptr; + return false; + + // 3) If `initFirst` is true (e.g., for reduction ops), the init block + // argument must be the last operand of the payload op, otherwise, the operands + // must match the block arguments in order. 
if (initFirst) { // check init if (payload.getOperands().back() != body->getArgument(0)) - return nullptr; + return false; // check rest for (const auto &[operand, bbArg] : llvm::zip(payload.getOperands(), body->getArguments().drop_front())) { if (bbArg != operand) - return nullptr; + return false; } } else { for (const auto &[operand, bbArg] : llvm::zip(payload.getOperands(), body->getArguments())) { if (bbArg != operand) - return nullptr; + return false; } } - return &payload; + + // 4) The `yield` operand must be the result of the payload op. + auto yieldOp = cast(body->getTerminator()); + return yieldOp.getNumOperands() == 1 && + yieldOp.getOperand(0).getDefiningOp() && + yieldOp.getOperand(0).getDefiningOp() == &payload; } -void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { +static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { SmallVector elidedAttrs; std::string attrToElide; p << " { " << payloadOp->getName().getStringRef(); @@ -1622,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { void MapOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - Operation *payloadOp = findPayloadOp(mapper); - if (payloadOp) { - printShortForm(p, payloadOp); + bool useShortForm = canUseShortForm(mapper); + if (useShortForm) { + printShortForm(p, &mapper->getOperations().front()); } printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); p.printOptionalAttrDict((*this)->getAttrs()); - if (!payloadOp) { + if (!useShortForm) { // Print region if the payload op was not detected. 
p.increaseIndent(); p.printNewline(); @@ -1829,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, void ReduceOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true); - if (payloadOp) { - printShortForm(p, payloadOp); + bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true); + if (useShortForm) { + printShortForm(p, &mapper->getOperations().front()); } printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); - if (!payloadOp) { + if (!useShortForm) { // Print region if the payload op was not detected. p.increaseIndent(); p.printNewline(); diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index 4edbc6eda3eae..563013d4083af 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>, // CHECK-SAME: outs // CHECK-SAME: dimensions = [1] + +// ----- + + +func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>, + %init: tensor<16x64xf32>) -> tensor<16x64xf32> { + %reduce = linalg.reduce + ins(%input:tensor<16x32x64xf32>) + outs(%init:tensor<16x64xf32>) + dimensions = [1] + (%in1: f32, %in2: f32) { + %0 = arith.addf %in1, %in2: f32 + linalg.yield %in1: f32 + } + func.return %reduce : tensor<16x64xf32> +} + +// CHECK-LABEL: func @reduce_not_short_form_compatible +// CHECK-SAME: %[[INPUT:.*]]: tensor<16x32x64xf32> +// CHECK-SAME: %[[INIT:.*]]: tensor<16x64xf32> +// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[INPUT]] : tensor<16x32x64xf32> +// CHECK: linalg.reduce ins(%[[INPUT]] : tensor<16x32x64xf32>) outs(%[[INIT]] : tensor<16x64xf32>) +// CHECK-SAME: dimensions = [1] +// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) { +// CHECK-NEXT: 
%[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32 +// CHECK-NEXT: linalg.yield %[[IN1]] : f32 +// CHECK-NEXT: } + // ----- func.func @reduce_memref(%input: memref<16x32x64xf32>, @@ -592,6 +620,27 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, // ----- +func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> { + %mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>) + (%in_1: f32, %in_2: f32) { + %1 = arith.maximumf %in_1, %in_2 : f32 + linalg.yield %in_1 : f32 + } + func.return %mapped : tensor<1x32xf32> +} + +// CHECK-LABEL: func @map_not_short_form_compatible +// CHECK-SAME: %[[LHS:.*]]: tensor<1x32xf32>, %[[RHS:.*]]: tensor<1x32xf32>, %[[INIT:.*]]: tensor<1x32xf32> +// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32> +// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>) +// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>) +// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) { +// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32 +// CHECK-NEXT: linalg.yield %[[IN1]] : f32 +// CHECK-NEXT: } + +// ----- + func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>, %init: tensor<16x64xf32>) -> tensor<16x64xf32> { %reduce = linalg.reduce