[mlir][linalg] Fix SemiFunctionType custom parsing crash on missing ()
#110365
The `SemiFunctionType` allows printing/parsing a set of argument and result types, where there is always exactly one argument type and zero or more result types. If there are no result types, the argument type can be written without enclosing parens in the assembly. If there is at least one result type, the parens are mandatory.

This patch fixes a bug where omitting the parens around the argument type of a `SemiFunctionType` with a non-optional result type would lead to a crash, because the parser accepted the bare argument type without ever setting the result type. It introduces a `bool` argument `resultOptional` to the parser and printer. When it is `false`, the parser requires the parens around the argument type and emits an error if they are missing, and the printer asserts that a result type is present.

Fix llvm#109128
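To make the grammar and the fix concrete, here is a minimal, self-contained C++ sketch. It does not use MLIR: `OpAsmParser` is replaced by a toy tokenizer, types are modeled as whitespace-free strings such as `!transform.any_op`, and everything in namespace `toy` (`SemiFunctionType`, `ParseOutcome`, `tokenize`, `parseSemiFunctionType`) is hypothetical. The control flow is assumed to follow the patched `parseSemiFunctionType`: the `(` may be omitted only when `resultOptional` is `true`.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Toy model of the SemiFunctionType grammar (not the MLIR API):
//   semi-function-type ::= type | `(` type `)` `->` type
namespace toy {

struct SemiFunctionType {
  std::string argument;
  std::optional<std::string> result; // empty when no result type was written
};

struct ParseOutcome {
  std::optional<SemiFunctionType> value; // set on success
  std::string error;                     // set on failure
};

// Splits on whitespace; '(', ')' and "->" always become separate tokens.
std::vector<std::string> tokenize(const std::string &text) {
  std::vector<std::string> tokens;
  std::string current;
  auto flush = [&] {
    if (!current.empty()) {
      tokens.push_back(current);
      current.clear();
    }
  };
  for (std::size_t i = 0; i < text.size(); ++i) {
    char c = text[i];
    if (std::isspace(static_cast<unsigned char>(c))) {
      flush();
    } else if (c == '(' || c == ')') {
      flush();
      tokens.push_back(std::string(1, c));
    } else if (c == '-' && i + 1 < text.size() && text[i + 1] == '>') {
      flush();
      tokens.push_back("->");
      ++i; // skip '>'
    } else {
      current += c;
    }
  }
  flush();
  return tokens;
}

ParseOutcome parseSemiFunctionType(const std::string &text,
                                   bool resultOptional = true) {
  const std::vector<std::string> toks = tokenize(text);
  std::size_t pos = 0;
  auto fail = [](std::string msg) {
    return ParseOutcome{std::nullopt, std::move(msg)};
  };
  auto isTypeAt = [&](std::size_t i) {
    return i < toks.size() && toks[i] != "(" && toks[i] != ")" &&
           toks[i] != "->";
  };
  auto isTokenAt = [&](std::size_t i, const char *tok) {
    return i < toks.size() && toks[i] == tok;
  };

  // The fix: '(' may be omitted only if the result type is optional.
  bool hasLParen = isTokenAt(pos, "(");
  if (hasLParen)
    ++pos;
  else if (!resultOptional)
    return fail("expected '('");

  if (!isTypeAt(pos))
    return fail("expected type");
  SemiFunctionType parsed{toks[pos++], std::nullopt};

  if (!hasLParen) {
    // The bare form is only reachable when the result is optional.
    assert(resultOptional);
    // Toy simplification: in MLIR the rest of the op's syntax would follow.
    if (pos != toks.size())
      return fail("unexpected trailing tokens");
    return ParseOutcome{parsed, ""};
  }
  if (!isTokenAt(pos++, ")"))
    return fail("expected ')'");
  if (!isTokenAt(pos++, "->"))
    return fail("expected '->'");
  if (!isTypeAt(pos))
    return fail("expected type");
  parsed.result = toks[pos++];
  if (pos != toks.size())
    return fail("unexpected trailing tokens");
  return ParseOutcome{parsed, ""};
}

} // namespace toy
```

With `resultOptional = false`, an input like `!transform.any_op -> !transform.any_op` is rejected with `expected '('`, mirroring the diagnostic checked by the new test in `transform-ops-invalid.mlir`, while `(!transform.any_op) -> !transform.any_op` parses and yields a result type. With the default `true`, the bare `!transform.any_op` is still accepted and carries no result.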
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir-linalg

Author: Felix Schneider (ubfx)

Changes: same as the PR description above. Fix #109128

Full diff: https://github.com/llvm/llvm-project/pull/110365.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index cdc29d053e5a4b..2da52bbf861668 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -541,9 +541,10 @@ def MatchStructuredRankOp : Op<Transform_Dialect, "match.structured.rank", [
let arguments = (ins TransformHandleTypeInterface:$operand_handle);
let results = (outs TransformParamTypeInterface:$rank);
- let assemblyFormat =
- "$operand_handle attr-dict `:`"
- "custom<SemiFunctionType>(type($operand_handle), type($rank))";
+ let assemblyFormat = [{
+ $operand_handle attr-dict `:`
+ custom<SemiFunctionType>(type($operand_handle), type($rank), "false")
+ }];
let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
}
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 106f0d79d9792d..efbba1eb065dca 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -382,9 +382,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);
- let assemblyFormat =
- "$target attr-dict `:` "
- "custom<SemiFunctionType>(type($target), type($transformed))";
+ let assemblyFormat = [{
+ $target attr-dict `:`
+ custom<SemiFunctionType>(type($target), type($transformed), "false")
+ }];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -419,9 +420,10 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);
- let assemblyFormat =
- "$target attr-dict `:` "
- "custom<SemiFunctionType>(type($target), type($transformed))";
+ let assemblyFormat = [{
+ $target attr-dict `:`
+ custom<SemiFunctionType>(type($target), type($transformed), "false")
+ }];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -464,7 +466,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
let assemblyFormat = [{
$target
(`iterator_interchange` `=` $iterator_interchange^)? attr-dict
- `:` custom<SemiFunctionType>(type($target), type($transformed))
+ `:` custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
let hasVerifier = 1;
@@ -1197,9 +1199,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
OptionalAttr<I64Attr>:$alignment);
let results = (outs TransformHandleTypeInterface:$transformed);
- let assemblyFormat =
- "$target attr-dict `:`"
- "custom<SemiFunctionType>(type($target), type($transformed))";
+ let assemblyFormat = [{
+ $target attr-dict `:`
+ custom<SemiFunctionType>(type($target), type($transformed), "false")
+ }];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -1233,9 +1236,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$replacement);
let regions = (region SizedRegion<1>:$bodyRegion);
- let assemblyFormat =
- "$target attr-dict-with-keyword regions `:` "
- "custom<SemiFunctionType>(type($target), type($replacement))";
+ let assemblyFormat = [{
+ $target attr-dict-with-keyword regions `:`
+ custom<SemiFunctionType>(type($target), type($replacement), "false")
+ }];
let hasVerifier = 1;
}
@@ -1274,9 +1278,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);
- let assemblyFormat =
- "$target attr-dict `:`"
- "custom<SemiFunctionType>(type($target), type($result))";
+ let assemblyFormat = [{
+ $target attr-dict `:`
+ custom<SemiFunctionType>(type($target), type($result), "false")
+ }];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
index 50e55e72226120..595e8aac1045fe 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
@@ -30,7 +30,7 @@ class Operation;
/// the argument type in absence of result types, and does not accept the
/// trailing `-> ()` construct, which makes the syntax nicer for operations.
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
- Type &resultType);
+ Type &resultType, bool resultOptional = true);
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
SmallVectorImpl<Type> &resultTypes);
@@ -40,7 +40,8 @@ ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
Type argumentType, TypeRange resultType);
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
- Type argumentType, Type resultType);
+ Type argumentType, Type resultType,
+ bool resultOptional = true);
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
diff --git a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
index e340228795cdef..44eac878394b86 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
@@ -32,7 +32,10 @@ def MatchSparseInOut : Op<Transform_Dialect, "sparse_tensor.match.sparse_inout",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);
- let assemblyFormat = "$target attr-dict `:` custom<SemiFunctionType>(type($target), type($result))";
+ let assemblyFormat = [{
+ $target attr-dict `:`
+ custom<SemiFunctionType>(type($target), type($result), "false")
+ }];
let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
::mlir::Value getOperandHandle() { return getTarget(); }
}];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp b/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
index 7ba0a6eb68f48c..266c9ad3314a32 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
@@ -12,9 +12,13 @@
using namespace mlir;
ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
- Type &resultType) {
+ Type &resultType, bool resultOptional) {
argumentType = resultType = nullptr;
- bool hasLParen = parser.parseOptionalLParen().succeeded();
+
+ bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded()
+ : parser.parseLParen().succeeded();
+ if (!resultOptional && !hasLParen)
+ return failure();
if (parser.parseType(argumentType).failed())
return failure();
if (!hasLParen)
@@ -69,7 +73,9 @@ void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
}
void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
- Type argumentType, Type resultType) {
+ Type argumentType, Type resultType,
+ bool resultOptional) {
+ assert(resultOptional || resultType != nullptr);
return printSemiFunctionType(printer, op, argumentType,
resultType ? TypeRange(resultType)
: TypeRange());
diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
index e86d4962530a9a..4bbd9bfd1443f4 100644
--- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -92,3 +92,11 @@ transform.sequence failures(propagate) {
transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>
}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error@below {{expected '('}}
+ %res = transform.structured.generalize %arg0 : !transform.any_op -> !transform.any_op
+}
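On the printer side, the patch only adds an assertion to the single-result `printSemiFunctionType` overload, which then forwards to the `TypeRange` overload. The sketch below models that in plain C++, with strings standing in for MLIR types. The namespace `toy` and both signatures are hypothetical, and the multi-result form `(arg) -> (r0, r1)` is an assumption about the `TypeRange` overload, whose body is not shown in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

namespace toy {

// Range form: the bare argument type when there are no results, otherwise
// `(arg) -> result`, parenthesizing the results when there is more than one.
std::string printSemiFunctionType(const std::string &argument,
                                  const std::vector<std::string> &results) {
  if (results.empty())
    return argument;
  std::ostringstream os;
  os << "(" << argument << ") -> ";
  if (results.size() > 1)
    os << "(";
  for (std::size_t i = 0; i < results.size(); ++i)
    os << (i ? ", " : "") << results[i];
  if (results.size() > 1)
    os << ")";
  return os.str();
}

// Single-result form: forwards to the range form, asserting (as the patch
// does) that a non-optional result is actually present.
std::string printSemiFunctionType(const std::string &argument,
                                  const std::optional<std::string> &result,
                                  bool resultOptional = true) {
  assert(resultOptional || result.has_value());
  if (result)
    return printSemiFunctionType(argument, std::vector<std::string>{*result});
  return printSemiFunctionType(argument, std::vector<std::string>{});
}

} // namespace toy
```

The assertion keeps the printer consistent with the stricter parser: for an op whose result is required, the printed form is always the parenthesized `(arg) -> result`, which the parser accepts when `resultOptional` is `false`.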
Ping
Max191 left a comment
I don't have experience working with the parser, but it seems like it might not be necessary to thread this resultOptional option through the parser function. Can it be done without the new option?
    if (!resultOptional && !hasLParen)
      return failure();
    if (parser.parseType(argumentType).failed())
      return failure();
    if (!hasLParen)
Does this need the new `resultOptional` option? Does it work if you do something like:

    argumentType = resultType = nullptr;
    bool hasLParen = parser.parseOptionalLParen().succeeded();
    if (parser.parseType(argumentType).failed())
      return failure();
    if (!hasLParen)
      return success(parser.parseRParen().failed() ||
                     parser.parseArrow().failed() ||
                     parser.parseType(resultType).failed());
    return failure(parser.parseRParen().failed() ||
                   parser.parseArrow().failed() ||
                   parser.parseType(resultType).failed());
This causes problems for ops with exactly one non-optional result, like `transform.structured.generalize`. We could write an invalid version of such an op like this:

    transform.structured.generalize %arg0 : !transform.any_op

This would successfully go through your parser, because it looks like a valid `SemiFunctionType` with no result types. However, it then leads to a crash during op creation, because `resultType` is never initialized even though the result is not optional.

In general, if we have a non-optional result, the parser must either fail or correctly read the result's type. Without the `resultOptional` argument enforcing that the result is read, ops can slip through the cracks by using the form without parens and result types.
I see. In that case there is no way for the parser to know that the result is required. Since these ops must have a result, it might be better for these transform ops to use a different assembly format (one that requires a result) instead of `SemiFunctionType`, but the extra option seems okay to me.
Max191 left a comment
This seems okay to me now, but maybe wait one more day for others to take a look.
@ftynse could you take a look (you are on git blame)?