-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][linalg] Fix incorrect linalg short form printing #153219
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir-linalg Author: Boyana Norris (brnorris03) ChangesBoth Fixes #117528 Full diff: https://github.com/llvm/llvm-project/pull/153219.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9d7fb18f56fef..7af4ea6a2f3a4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1570,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+ // Check if the body can be printed in short form. The following 4 conditions
+ // must be satisfied:
+
+ // 1) The body must contain exactly 2 operations: the payload op and a yield.
if (body->getOperations().size() != 2)
- return nullptr;
+ return false;
Operation &payload = body->getOperations().front();
- assert(isa<YieldOp>(body->getOperations().back()));
+ // 2) The payload op must have the same number of operands as the number of
+ // block arguments.
if (payload.getNumOperands() == 0 ||
payload.getNumOperands() != body->getNumArguments())
- return nullptr;
+ return false;
+
+ // 3) If `initFirst` is true (e.g., for reduction ops), the init block
+ // must be the first operand of the payload op, otherwise, the operands
+ // must match the block arguments in order.
if (initFirst) {
// check init
if (payload.getOperands().back() != body->getArgument(0))
- return nullptr;
+ return false;
// check rest
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
} else {
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
}
- return &payload;
+
+ // 4) The `yield` operand must be the result of the payload op.
+ auto yieldOp = cast<YieldOp>(body->getTerminator());
+ return yieldOp.getNumOperands() == 1 &&
+ yieldOp.getOperand(0).getDefiningOp() &&
+ yieldOp.getOperand(0).getDefiningOp() == &payload;
}
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
SmallVector<StringRef> elidedAttrs;
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
@@ -1622,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
p.printOptionalAttrDict((*this)->getAttrs());
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -1829,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
void ReduceOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 4edbc6eda3eae..a09348c69d3a3 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
// CHECK-SAME: outs
// CHECK-SAME: dimensions = [1]
+
+// -----
+
+
+func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
+ %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [1]
+ (%in1: f32, %in2: f32) {
+ %0 = arith.addf %in1, %in2: f32
+ linalg.yield %in1: f32
+ }
+ func.return %reduce : tensor<16x64xf32>
+}
+
+// CHECK-LABEL: func @reduce_not_short_form_compatible
+// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32x64xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<16x64xf32>
+// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0]] : tensor<16x32x64xf32>
+// CHECK: linalg.reduce ins(%[[ARG0]] : tensor<16x32x64xf32>) outs(%[[ARG1]] : tensor<16x64xf32>)
+// CHECK-SAME: dimensions = [1]
+// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT: linalg.yield %[[IN1]] : f32
+// CHECK-NEXT: }
+
// -----
func.func @reduce_memref(%input: memref<16x32x64xf32>,
@@ -592,6 +620,29 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
// -----
+func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
+ %res = tensor.empty() : tensor<1x32xf32>
+ %mapped = linalg.map ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%res : tensor<1x32xf32>)
+ (%in_1: f32, %in_2: f32) {
+ %1 = arith.maximumf %in_1, %in_2 : f32
+ linalg.yield %in_1 : f32
+ }
+ func.return %mapped : tensor<1x32xf32>
+}
+
+// CHECK-LABEL: func @map_not_short_form_compatible
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
+// CHECK: %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
+// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
+// CHECK: linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>)
+// CHECK-SAME: outs(%[[RES]] : tensor<1x32xf32>)
+// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT: linalg.yield %[[IN1]] : f32
+// CHECK-NEXT: }
+
+// -----
+
func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
%reduce = linalg.reduce
|
|
@llvm/pr-subscribers-mlir Author: Boyana Norris (brnorris03) ChangesBoth Fixes #117528 Full diff: https://github.com/llvm/llvm-project/pull/153219.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9d7fb18f56fef..7af4ea6a2f3a4 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1570,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+ // Check if the body can be printed in short form. The following 4 conditions
+ // must be satisfied:
+
+ // 1) The body must contain exactly 2 operations: the payload op and a yield.
if (body->getOperations().size() != 2)
- return nullptr;
+ return false;
Operation &payload = body->getOperations().front();
- assert(isa<YieldOp>(body->getOperations().back()));
+ // 2) The payload op must have the same number of operands as the number of
+ // block arguments.
if (payload.getNumOperands() == 0 ||
payload.getNumOperands() != body->getNumArguments())
- return nullptr;
+ return false;
+
+ // 3) If `initFirst` is true (e.g., for reduction ops), the init block
+ // must be the first operand of the payload op, otherwise, the operands
+ // must match the block arguments in order.
if (initFirst) {
// check init
if (payload.getOperands().back() != body->getArgument(0))
- return nullptr;
+ return false;
// check rest
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
} else {
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
}
- return &payload;
+
+ // 4) The `yield` operand must be the result of the payload op.
+ auto yieldOp = cast<YieldOp>(body->getTerminator());
+ return yieldOp.getNumOperands() == 1 &&
+ yieldOp.getOperand(0).getDefiningOp() &&
+ yieldOp.getOperand(0).getDefiningOp() == &payload;
}
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
SmallVector<StringRef> elidedAttrs;
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
@@ -1622,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
p.printOptionalAttrDict((*this)->getAttrs());
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -1829,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
void ReduceOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 4edbc6eda3eae..a09348c69d3a3 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -436,6 +436,34 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
// CHECK-SAME: outs
// CHECK-SAME: dimensions = [1]
+
+// -----
+
+
+func.func @reduce_not_short_form_compatible(%input: tensor<16x32x64xf32>,
+ %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+ %reduce = linalg.reduce
+ ins(%input:tensor<16x32x64xf32>)
+ outs(%init:tensor<16x64xf32>)
+ dimensions = [1]
+ (%in1: f32, %in2: f32) {
+ %0 = arith.addf %in1, %in2: f32
+ linalg.yield %in1: f32
+ }
+ func.return %reduce : tensor<16x64xf32>
+}
+
+// CHECK-LABEL: func @reduce_not_short_form_compatible
+// CHECK-SAME: %[[ARG0:.*]]: tensor<16x32x64xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<16x64xf32>
+// CHECK-NOT: linalg.reduce { arith.addf } ins(%[[ARG0]] : tensor<16x32x64xf32>
+// CHECK: linalg.reduce ins(%[[ARG0]] : tensor<16x32x64xf32>) outs(%[[ARG1]] : tensor<16x64xf32>)
+// CHECK-SAME: dimensions = [1]
+// CHECK: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: %[[ADD_RESULT:.*]] = arith.addf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT: linalg.yield %[[IN1]] : f32
+// CHECK-NEXT: }
+
// -----
func.func @reduce_memref(%input: memref<16x32x64xf32>,
@@ -592,6 +620,29 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
// -----
+func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> {
+ %res = tensor.empty() : tensor<1x32xf32>
+ %mapped = linalg.map ins(%arg0, %arg1 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%res : tensor<1x32xf32>)
+ (%in_1: f32, %in_2: f32) {
+ %1 = arith.maximumf %in_1, %in_2 : f32
+ linalg.yield %in_1 : f32
+ }
+ func.return %mapped : tensor<1x32xf32>
+}
+
+// CHECK-LABEL: func @map_not_short_form_compatible
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32xf32>, %[[ARG1:.*]]: tensor<1x32xf32>
+// CHECK: %[[RES:.*]] = tensor.empty() : tensor<1x32xf32>
+// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[ARG0]] : tensor<1x32xf32>
+// CHECK: linalg.map ins(%[[ARG0]], %[[ARG1]] : tensor<1x32xf32>, tensor<1x32xf32>)
+// CHECK-SAME: outs(%[[RES]] : tensor<1x32xf32>)
+// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
+// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
+// CHECK-NEXT: linalg.yield %[[IN1]] : f32
+// CHECK-NEXT: }
+
+// -----
+
func.func @reduce_arith_with_attr(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
%reduce = linalg.reduce
|
rengolin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, thanks!
banach-space
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thank you! LGTM % nit
Btw, it would be useful to update the docs to clarify what the short-form is:
Right now it is not super clear.
| func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> { | ||
| %res = tensor.empty() : tensor<1x32xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit 1] Use meaningful and consistent names for function arguments.
[nit 2] Keep the number of ops to the required minimum.
| func.func @map_not_short_form_compatible(%arg0: tensor<1x32xf32>, %arg1: tensor<1x32xf32>) -> tensor<1x32xf32> { | |
| %res = tensor.empty() : tensor<1x32xf32> | |
| func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %res: tensor<1x32xf32>) -> tensor<1x32xf32> { | |
| %res = tensor.empty() : tensor<1x32xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated, thanks for the suggestion!
|
Thanks again for fixing this! Do you have commit access? If not, I can land this for you. |
|
@banach-space you can know by looking on the top-right of PR next to the contributor name, people with commit access will display "member" and otherwise it says "contributor" (or "first-time contributor"). |
I don't have commit access (yet -- will request it when I meet the criteria), so if you can merge, that would be great, thank you! |
Both
linalg.mapandlinalg.reduceare sometimes printed in short form incorrectly, resulting in a round-trip output with different semantics. This patch adds additionalyieldoperand checks to ensure that all criteria for short-form printing are satisfied. Updated/added comments and renamed thefindPayloadOpfunction tocanUseShortForm, which more accurately reflects its purpose. A couple of new lit tests check for the proper use of long form when short-form conditions are not met.Fixes #117528