@ubfx
Member

@ubfx ubfx commented Sep 28, 2024

The SemiFunctionType allows printing/parsing a set of argument and result types, where there is always exactly one argument type and zero or more result types. If there are no result types, the argument type can be written without enclosing parens in the assembly. If there is at least one result type, the parens are mandatory.
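To make the rule concrete, here is a minimal, self-contained C++ sketch of the printing side. It works on type spellings as plain strings rather than real `mlir::Type`s. The function name `printSemiFunctionTypeSketch` is hypothetical, and the parenthesized multi-result form is an assumption based on MLIR's usual function-type syntax. The `resultOptional` flag mirrors the assertion this patch adds to the single-result printer.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy model of the SemiFunctionType printing rule (hypothetical helper,
// not MLIR's OpAsmPrinter API). Types are plain strings here.
//   no results      -> `arg`                (parens omitted)
//   one result      -> `(arg) -> res`
//   several results -> `(arg) -> (r0, r1)`  (assumed functional-type style)
// resultOptional == false mirrors the patch's assertion: an op whose result
// is not optional must always have a result type to print.
std::string printSemiFunctionTypeSketch(
    const std::string &argumentType,
    const std::vector<std::string> &resultTypes, bool resultOptional = true) {
  assert(resultOptional || !resultTypes.empty());
  if (resultTypes.empty())
    return argumentType;  // bare argument type, no parens

  std::string out = "(" + argumentType + ") -> ";
  const bool wrapResults = resultTypes.size() > 1;
  if (wrapResults)
    out += "(";
  for (std::size_t i = 0; i < resultTypes.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += resultTypes[i];
  }
  if (wrapResults)
    out += ")";
  return out;
}
```

For example, an empty result list prints just `!transform.any_op`, while a single result prints `(!transform.any_op) -> !transform.any_op`.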

This patch fixes a bug where omitting the parens around the argument type of a `SemiFunctionType` whose result type is not optional would lead to a crash. It introduces a `bool` argument `resultOptional` to the parser and printer. When it is `false`, the parser requires the parens around the argument type and reports a parse error if they are missing, and the printer asserts that a result type is present.

Fix #109128
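The parsing side of the fix can be sketched the same way. This is a toy model, assuming a simplified string tokenizer in place of `OpAsmParser`; the names `parseSemiFunctionTypeSketch`, `SemiParseResult` and `Cursor` are hypothetical, and only the `expected '('` message is taken from the test added in this patch. It follows the single-result overload: with `resultOptional == false`, the bare form is rejected before any type is read, so a required result type can never be silently left unset.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>

// Result of the toy parser: parsed types on success, otherwise an error.
struct SemiParseResult {
  bool ok = false;
  std::string argumentType;
  std::optional<std::string> resultType;  // nullopt: no result type parsed
  std::string error;
};

namespace {
// Minimal cursor over the input; a hypothetical stand-in for OpAsmParser.
struct Cursor {
  const std::string &text;
  std::size_t pos = 0;

  void skipSpace() {
    while (pos < text.size() &&
           std::isspace(static_cast<unsigned char>(text[pos])))
      ++pos;
  }
  // Consumes `tok` if it is the next token (after whitespace).
  bool consume(const std::string &tok) {
    skipSpace();
    if (text.compare(pos, tok.size(), tok) != 0)
      return false;
    pos += tok.size();
    return true;
  }
  // Toy "type": a run of characters other than whitespace, parens, commas.
  std::optional<std::string> parseType() {
    skipSpace();
    const std::size_t start = pos;
    while (pos < text.size() &&
           !std::isspace(static_cast<unsigned char>(text[pos])) &&
           text[pos] != '(' && text[pos] != ')' && text[pos] != ',')
      ++pos;
    if (pos == start)
      return std::nullopt;
    return text.substr(start, pos - start);
  }
  bool atEnd() {
    skipSpace();
    return pos == text.size();
  }
};
} // namespace

// Accepts `arg` or `(arg) -> res`. With resultOptional == false the bare
// `arg` form is rejected up front, so a required result is always parsed.
SemiParseResult parseSemiFunctionTypeSketch(const std::string &input,
                                            bool resultOptional = true) {
  SemiParseResult r;
  Cursor cur{input};
  const bool hasLParen = cur.consume("(");
  if (!resultOptional && !hasLParen) {
    r.error = "expected '('";  // same message the new test expects
    return r;
  }
  std::optional<std::string> arg = cur.parseType();
  if (!arg) {
    r.error = "expected type";
    return r;
  }
  r.argumentType = *arg;
  if (hasLParen) {
    if (!cur.consume(")")) { r.error = "expected ')'"; return r; }
    if (!cur.consume("->")) { r.error = "expected '->'"; return r; }
    r.resultType = cur.parseType();
    if (!r.resultType) { r.error = "expected type"; return r; }
  }
  // The toy rejects trailing input itself; in MLIR the enclosing op parser
  // would diagnose it instead.
  if (!cur.atEnd()) {
    r.error = "unexpected trailing input";
    return r;
  }
  r.ok = true;
  return r;
}
```

With `resultOptional` left at its default of `true`, `!transform.any_op` parses successfully with no result type, which is the input that previously reached op creation for ops such as `transform.structured.generalize`; passing `false` turns it into a parse error.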

@llvmbot
Member

llvmbot commented Sep 28, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir-linalg

Author: Felix Schneider (ubfx)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/110365.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td (+4-3)
  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+21-16)
  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h (+3-2)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td (+4-1)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp (+9-3)
  • (modified) mlir/test/Dialect/Linalg/transform-ops-invalid.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index cdc29d053e5a4b..2da52bbf861668 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -541,9 +541,10 @@ def MatchStructuredRankOp : Op<Transform_Dialect, "match.structured.rank", [
 
   let arguments = (ins TransformHandleTypeInterface:$operand_handle);
   let results = (outs TransformParamTypeInterface:$rank);
-  let assemblyFormat =
-      "$operand_handle attr-dict `:`"
-      "custom<SemiFunctionType>(type($operand_handle), type($rank))";
+  let assemblyFormat = [{
+      $operand_handle attr-dict `:`
+      custom<SemiFunctionType>(type($operand_handle), type($rank), "false")
+  }];
 
   let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
 }
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 106f0d79d9792d..efbba1eb065dca 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -382,9 +382,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
 
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$transformed);
-  let assemblyFormat =
-      "$target attr-dict `:` "
-      "custom<SemiFunctionType>(type($target), type($transformed))";
+  let assemblyFormat = [{
+      $target attr-dict `:` 
+      custom<SemiFunctionType>(type($target), type($transformed), "false")
+  }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -419,9 +420,10 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",
 
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$transformed);
-  let assemblyFormat =
-      "$target attr-dict `:` "
-      "custom<SemiFunctionType>(type($target), type($transformed))";
+  let assemblyFormat = [{
+      $target attr-dict `:` 
+      custom<SemiFunctionType>(type($target), type($transformed), "false")
+  }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -464,7 +466,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
   let assemblyFormat = [{
     $target
     (`iterator_interchange` `=` $iterator_interchange^)? attr-dict
-    `:` custom<SemiFunctionType>(type($target), type($transformed))
+    `:` custom<SemiFunctionType>(type($target), type($transformed), "false")
   }];
   let hasVerifier = 1;
 
@@ -1197,9 +1199,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
                        OptionalAttr<I64Attr>:$alignment);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
-  let assemblyFormat =
-    "$target attr-dict `:`"
-    "custom<SemiFunctionType>(type($target), type($transformed))";
+  let assemblyFormat = [{
+    $target attr-dict `:`
+    custom<SemiFunctionType>(type($target), type($transformed), "false")
+  }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
@@ -1233,9 +1236,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$replacement);
   let regions = (region SizedRegion<1>:$bodyRegion);
-  let assemblyFormat =
-      "$target attr-dict-with-keyword regions `:` "
-      "custom<SemiFunctionType>(type($target), type($replacement))";
+  let assemblyFormat = [{
+      $target attr-dict-with-keyword regions `:`
+      custom<SemiFunctionType>(type($target), type($replacement), "false")
+  }];
   let hasVerifier = 1;
 }
 
@@ -1274,9 +1278,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$result);
 
-  let assemblyFormat =
-    "$target attr-dict `:`"
-    "custom<SemiFunctionType>(type($target), type($result))";
+  let assemblyFormat = [{
+    $target attr-dict `:`
+    custom<SemiFunctionType>(type($target), type($result), "false")
+  }];
 
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
index 50e55e72226120..595e8aac1045fe 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
@@ -30,7 +30,7 @@ class Operation;
 /// the argument type in absence of result types, and does not accept the
 /// trailing `-> ()` construct, which makes the syntax nicer for operations.
 ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
-                                  Type &resultType);
+                                  Type &resultType, bool resultOptional = true);
 ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
                                   SmallVectorImpl<Type> &resultTypes);
 
@@ -40,7 +40,8 @@ ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
 void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
                            Type argumentType, TypeRange resultType);
 void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
-                           Type argumentType, Type resultType);
+                           Type argumentType, Type resultType,
+                           bool resultOptional = true);
 } // namespace mlir
 
 #endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
diff --git a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
index e340228795cdef..44eac878394b86 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
@@ -32,7 +32,10 @@ def MatchSparseInOut : Op<Transform_Dialect, "sparse_tensor.match.sparse_inout",
   let arguments = (ins TransformHandleTypeInterface:$target);
   let results = (outs TransformHandleTypeInterface:$result);
 
-  let assemblyFormat = "$target attr-dict `:` custom<SemiFunctionType>(type($target), type($result))";
+  let assemblyFormat = [{
+    $target attr-dict `:`
+    custom<SemiFunctionType>(type($target), type($result), "false")
+  }];
   let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
     ::mlir::Value getOperandHandle() { return getTarget(); }
   }];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp b/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
index 7ba0a6eb68f48c..266c9ad3314a32 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
@@ -12,9 +12,13 @@
 using namespace mlir;
 
 ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
-                                        Type &resultType) {
+                                        Type &resultType, bool resultOptional) {
   argumentType = resultType = nullptr;
-  bool hasLParen = parser.parseOptionalLParen().succeeded();
+
+  bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded()
+                                  : parser.parseLParen().succeeded();
+  if (!resultOptional && !hasLParen)
+    return failure();
   if (parser.parseType(argumentType).failed())
     return failure();
   if (!hasLParen)
@@ -69,7 +73,9 @@ void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
 }
 
 void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
-                                 Type argumentType, Type resultType) {
+                                 Type argumentType, Type resultType,
+                                 bool resultOptional) {
+  assert(resultOptional || resultType != nullptr);
   return printSemiFunctionType(printer, op, argumentType,
                                resultType ? TypeRange(resultType)
                                           : TypeRange());
diff --git a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
index e86d4962530a9a..4bbd9bfd1443f4 100644
--- a/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
@@ -92,3 +92,11 @@ transform.sequence failures(propagate) {
   transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>
 
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error@below {{expected '('}}
+  %res = transform.structured.generalize %arg0 : !transform.any_op -> !transform.any_op 
+}

@ubfx ubfx requested a review from Max191 October 1, 2024 08:43
@ubfx
Member Author

ubfx commented Oct 20, 2024

Ping

Contributor

@Max191 Max191 left a comment

I don't have experience working with the parser, but it seems like it might not be necessary to thread this resultOptional option through the parser function. Can it be done without the new option?

Comment on lines +20 to 24
  if (!resultOptional && !hasLParen)
    return failure();
  if (parser.parseType(argumentType).failed())
    return failure();
  if (!hasLParen)
Contributor

Does this need the new resultOptional option? Does it work if you do something like:

  argumentType = resultType = nullptr;
  bool hasLParen = parser.parseOptionalLParen().succeeded();
  if (parser.parseType(argumentType).failed())
    return failure();
  if (!hasLParen)
    return success(parser.parseRParen().failed() ||
                   parser.parseArrow().failed() ||
                   parser.parseType(resultType).failed());
  return failure(parser.parseRParen().failed() ||
                 parser.parseArrow().failed() ||
                 parser.parseType(resultType).failed());

Member Author

This causes problems for ops with exactly one non-optional result, like `transform.structured.generalize`. We could write an invalid version of such an op like this:

transform.structured.generalize %arg0 : !transform.any_op

This would pass through your parser because it looks like a valid `SemiFunctionType` whose result is optional (here absent). However, it then leads to a crash during op creation, because `resultType` is never initialized even though the result is not optional.
In general, if an op has a non-optional result, the parser must either fail or correctly read the result's type. Without the `resultOptional` argument enforcing that the result is read, such ops can slip through the cracks by using the form without parens and result types.

Contributor

I see. In that case there is no way for the parser to know that the result is required. Considering that these ops must have a result, perhaps it would be better for these transform ops to use some different type of assembly format (that requires a result) instead of SemiFunctionType, but the extra option seems okay to me.

Contributor

@Max191 Max191 left a comment

This seems okay to me now, but maybe wait one more day for others to take a look.

@Max191
Contributor

Max191 commented Oct 22, 2024

@ftynse could you take a look (you are on git blame)?

@ubfx ubfx merged commit a07b422 into llvm:main Nov 3, 2024
8 checks passed
@ubfx ubfx deleted the linalg-semifunctiontype-parse branch November 3, 2024 14:31
PhilippRados pushed a commit to PhilippRados/llvm-project that referenced this pull request Nov 6, 2024
… `()` (llvm#110365)


Labels

mlir:linalg mlir:sparse Sparse compiler in MLIR mlir

Development

Successfully merging this pull request may close these issues.

[mlir] [linalg] Generalize linalg.depthwise_conv_2d_nhwc_hwcm crash

3 participants