Conversation

@brnorris03
Contributor

@brnorris03 brnorris03 commented Aug 12, 2025

Both linalg.map and linalg.reduce are sometimes printed in short form incorrectly, resulting in round-trip output with different semantics. This patch adds additional yield-operand checks to ensure that all criteria for short-form printing are satisfied. It also updates and adds comments and renames the findPayloadOp function to canUseShortForm, which more accurately reflects its purpose. Two new lit tests check that the long form is used when the short-form conditions are not met.

Fixes #117528
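For readers skimming the diff below, the eligibility rules can be sketched abstractly. The following is a minimal Python model of the canUseShortForm logic — a hedged illustration only, not the actual MLIR C++ API; the Value and Op classes and all names are hypothetical stand-ins:

```python
class Value:
    """A toy SSA value; defining_op is None for block arguments."""
    def __init__(self, name, defining_op=None):
        self.name = name
        self.defining_op = defining_op

class Op:
    """A toy operation with a single result."""
    def __init__(self, name, operands):
        self.name = name
        self.operands = operands
        self.result = Value(name + "_result", defining_op=self)

def can_use_short_form(block_args, ops, init_first=False):
    # 1) The body must contain exactly two ops: the payload and the yield.
    if len(ops) != 2:
        return False
    payload, yield_op = ops
    # 2) The payload needs at least one operand, and exactly as many
    #    operands as there are block arguments.
    if not payload.operands or len(payload.operands) != len(block_args):
        return False
    # 3) Mirroring the C++ checks: with init_first (reductions), the
    #    payload's last operand must be block argument 0 and the leading
    #    operands must line up with the block arguments shifted by one;
    #    otherwise operands must match the block arguments in order.
    if init_first:
        if payload.operands[-1] is not block_args[0]:
            return False
        if any(o is not a for o, a in zip(payload.operands, block_args[1:])):
            return False
    else:
        if any(o is not a for o, a in zip(payload.operands, block_args)):
            return False
    # 4) The check added by this patch: the single yield operand must be
    #    the payload op's own result (not, e.g., a forwarded block argument).
    return len(yield_op.operands) == 1 and yield_op.operands[0] is payload.result

# Map-style body that IS short-form compatible:
in1, in2 = Value("%in1"), Value("%in2")
add = Op("arith.addf", [in1, in2])
print(can_use_short_form([in1, in2], [add, Op("linalg.yield", [add.result])]))  # True

# The pattern from issue #117528: the yield forwards %in1, not the sum,
# so the short form ("{ arith.addf }") would change semantics on round-trip:
print(can_use_short_form([in1, in2], [add, Op("linalg.yield", [in1])]))  # False
```

The model is deliberately structural (identity comparisons between toy values), which is all the real printer needs to decide between the two forms.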

@brnorris03 brnorris03 marked this pull request as ready for review August 12, 2025 19:59
@llvmbot
Member

llvmbot commented Aug 12, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Boyana Norris (brnorris03)

Changes

Both linalg.map and linalg.reduce are printed in short form incorrectly, resulting in a round-trip output with different semantics. This patch adds additional yield operand checks to ensure that all criteria for short-form printing are satisfied. Updated/added comments and renamed the findPayloadOp function to canUseShortForm, which more accurately reflects its purpose. A couple of new lit tests check for the proper use of long form when short-form conditions are not met.

Fixes #117528


Full diff: https://github.com/llvm/llvm-project/pull/153219.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+31-21)
  • (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+51)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9d7fb18f56fef..7af4ea6a2f3a4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1570,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+  // Check if the body can be printed in short form. The following 4 conditions
+  // must be satisfied:
+
+  // 1) The body must contain exactly 2 operations: the payload op and a yield.
   if (body->getOperations().size() != 2)
-    return nullptr;
+    return false;
   Operation &payload = body->getOperations().front();
-  assert(isa<YieldOp>(body->getOperations().back()));
 
+  // 2) The payload op must have at least one operand, and its number of
+  //    operands must match the number of block arguments.
   if (payload.getNumOperands() == 0 ||
       payload.getNumOperands() != body->getNumArguments())
-    return nullptr;
+    return false;
+
+  // 3) If `initFirst` is true (e.g., for reduction ops), the init block
+  //    argument must be the first operand of the payload op; otherwise, the
+  //    operands must match the block arguments in order.
   if (initFirst) {
     // check init
     if (payload.getOperands().back() != body->getArgument(0))
-      return nullptr;
+      return false;
     // check rest
     for (const auto &[operand, bbArg] :
          llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
       if (bbArg != operand)
-        return nullptr;
+        return false;
     }
   } else {
     for (const auto &[operand, bbArg] :
          llvm::zip(payload.getOperands(), body->getArguments())) {
       if (bbArg != operand)
-        return nullptr;
+        return false;
     }
   }
-  return &payload;
+
+  // 4) The `yield` operand must be the result of the payload op.
+  auto yieldOp = cast<YieldOp>(body->getTerminator());
+  return yieldOp.getNumOperands() == 1 &&
+         yieldOp.getOperand(0).getDefiningOp() &&
+         yieldOp.getOperand(0).getDefiningOp() == &payload;
 }
 
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
   SmallVector<StringRef> elidedAttrs;
   std::string attrToElide;
   p << " { " << payloadOp->getName().getStringRef();
@@ -1622,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
 
 void MapOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  Operation *payloadOp = findPayloadOp(mapper);
-  if (payloadOp) {
-    printShortForm(p, payloadOp);
+  bool useShortForm = canUseShortForm(mapper);
+  if (useShortForm) {
+    printShortForm(p, &mapper->getOperations().front());
   }
 
   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
   p.printOptionalAttrDict((*this)->getAttrs());
 
-  if (!payloadOp) {
+  if (!useShortForm) {
     // Print region if the payload op was not detected.
     p.increaseIndent();
     p.printNewline();
@@ -1829,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
 
 void ReduceOp::print(OpAsmPrinter &p) {
   Block *mapper = getBody();
-  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
-  if (payloadOp) {
-    printShortForm(p, payloadOp);
+  bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
+  if (useShortForm) {
+    printShortForm(p, &mapper->getOperations().front());
   }
 
   printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
   printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
   p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
-  if (!payloadOp) {
+  if (!useShortForm) {
     // Print region if the payload op was not detected.
     p.increaseIndent();
     p.printNewline();
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 4edbc6eda3eae..a09348c69d3a3 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
 //  CHECK-SAME:   outs
 //  CHECK-SAME:   dimensions = [1]
 
+
+// -----
+
+
+func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
+                             %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+  %reduce = linalg.reduce
+    ins(%input:tensor<16x32x64xf32>)
+    outs(%init:tensor<16x64xf32>)
+    dimensions = [1]
+    (%in1: f32, %in2: f32) {
+      %0 = arith.addf %in1, %in2: f32
+      linalg.yield %in1: f32
+    }
+  func.return %reduce : tensor<16x64xf32>
+}
+
+// CHECK-LABEL: func @reduce_not_short_form_compatible
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<16x32x64xf32>
+// CHECK-SAME:  %[[ARG1:.*]]: tensor<16x64xf32>
+// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0]] : tensor<16x32x64xf32>
+// CHECK:     linalg.reduce ins(%[[ARG0]] : tensor<16x32x64xf32>) outs(%[[ARG1]] : tensor<16x64xf32>) 
+// CHECK-SAME:  dimensions = [1]
+// CHECK:       (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT:    %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT:    linalg.yield %[[IN1]] : f32
+// CHECK-NEXT:  }
+
 // -----
 
 func.func @reduce_memref(%input: memref<16x32x64xf32>,
@@ -592,6 +620,29 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
 
 // -----
 
+func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
+  %res = tensor.empty() : tensor<1x32xf32>
+  %mapped = linalg.map ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%res : tensor<1x32xf32>)
+    (%in_1: f32, %in_2: f32) {
+      %1 = arith.maximumf %in_1, %in_2 : f32
+      linalg.yield %in_1 : f32
+    }
+  func.return %mapped : tensor<1x32xf32>
+}
+
+// CHECK-LABEL: func @map_not_short_form_compatible
+// CHECK-SAME:        %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
+// CHECK:         %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
+// CHECK-NOT:     linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
+// CHECK:         linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>) 
+// CHECK-SAME:               outs(%[[RES]] : tensor<1x32xf32>)
+// CHECK-NEXT:      (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT:        %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT:        linalg.yield %[[IN1]] : f32
+// CHECK-NEXT:    }
+
+// -----
+
 func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
                   %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
   %reduce = linalg.reduce

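To make the round-trip hazard concrete, here is a hedged toy sketch in Python — not MLIR's API; all names are hypothetical — of why printing the short form for an ineligible body changes semantics: the short form records only the payload op name, so reparsing always rebuilds a body that yields the payload result.

```python
def print_short_form(payload_name):
    # The short form records only the payload op name, e.g. "{ arith.addf }".
    return "{ " + payload_name + " }"

def parse_short_form(text, num_inputs):
    # Reparsing rebuilds the canonical body: the payload applied to the
    # block arguments, terminated by a yield of the payload's result.
    payload_name = text.strip().strip("{}").strip()
    return {
        "payload": payload_name,
        "operands": list(range(num_inputs)),
        "yields_payload_result": True,
    }

# Abstract form of the body from issue #117528:
#   %0 = arith.addf %in1, %in2 : f32
#   linalg.yield %in1 : f32        <- forwards an input, not %0
original = {
    "payload": "arith.addf",
    "operands": [0, 1],
    "yields_payload_result": False,
}

roundtripped = parse_short_form(print_short_form(original["payload"]), 2)
print(roundtripped == original)  # False: the yield target was lost

# canUseShortForm's condition 4 rejects this body up front, so the printer
# falls back to the long form and the round trip stays faithful.
```

This is exactly what the two new lit tests assert at the MLIR level: the printer must not emit "{ arith.addf }" or "{ arith.maximumf }" for these bodies.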
@llvmbot
Member

llvmbot commented Aug 12, 2025

@llvm/pr-subscribers-mlir

Author: Boyana Norris (brnorris03)

(Same changes summary and full diff as in the @llvm/pr-subscribers-mlir-linalg notification above.)

Member

@rengolin rengolin left a comment

Looks good to me, thanks!

Contributor

@banach-space banach-space left a comment

Nice, thank you! LGTM % nit

Btw, it would be useful to update the docs to clarify what the short form is; right now it is not super clear.

Comment on lines 623 to 624
func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
%res = tensor.empty() : tensor<1x32xf32>
Contributor

[nit 1] Use meaningful and consistent names for function arguments.
[nit 2] Keep the number of ops to the required minimum.

Suggested change:
-func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
-%res = tensor.empty() : tensor<1x32xf32>
+func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %res: tensor<1x32xf32>) -> tensor<1x32xf32> {
+%res = tensor.empty() : tensor<1x32xf32>

Contributor Author

Updated, thanks for the suggestion!

@banach-space
Contributor

Thanks again for fixing this!

Do you have commit access? If not, I can land this for you.

@joker-eph
Collaborator

@banach-space you can tell by looking at the top right of the PR next to the contributor name: people with commit access are displayed as "Member", and otherwise it says "Contributor" (or "First-time contributor").

@brnorris03
Contributor Author

> Thanks again for fixing this!
>
> Do you have commit access? If not, I can land this for you.

I don't have commit access (yet -- will request it when I meet the criteria), so if you can merge, that would be great, thank you!

@rengolin rengolin merged commit 1945753 into llvm:main Aug 14, 2025
9 checks passed


Development

Successfully merging this pull request may close these issues.

MLIR Linalg short form format is used when it should not be

5 participants