Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,10 @@ def MatchStructuredRankOp : Op<Transform_Dialect, "match.structured.rank", [

let arguments = (ins TransformHandleTypeInterface:$operand_handle);
let results = (outs TransformParamTypeInterface:$rank);
let assemblyFormat =
"$operand_handle attr-dict `:`"
"custom<SemiFunctionType>(type($operand_handle), type($rank))";
let assemblyFormat = [{
$operand_handle attr-dict `:`
custom<SemiFunctionType>(type($operand_handle), type($rank), "false")
}];

let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,10 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
"$target attr-dict `:` "
"custom<SemiFunctionType>(type($target), type($transformed))";
let assemblyFormat = [{
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($transformed), "false")
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
Expand Down Expand Up @@ -419,9 +420,10 @@ def SpecializeOp : Op<Transform_Dialect, "structured.specialize",

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat =
"$target attr-dict `:` "
"custom<SemiFunctionType>(type($target), type($transformed))";
let assemblyFormat = [{
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($transformed), "false")
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
Expand Down Expand Up @@ -464,7 +466,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
let assemblyFormat = [{
$target
(`iterator_interchange` `=` $iterator_interchange^)? attr-dict
`:` custom<SemiFunctionType>(type($target), type($transformed))
`:` custom<SemiFunctionType>(type($target), type($transformed), "false")
}];
let hasVerifier = 1;

Expand Down Expand Up @@ -1197,9 +1199,10 @@ def PromoteOp : Op<Transform_Dialect, "structured.promote",
OptionalAttr<I64Attr>:$alignment);
let results = (outs TransformHandleTypeInterface:$transformed);

let assemblyFormat =
"$target attr-dict `:`"
"custom<SemiFunctionType>(type($target), type($transformed))";
let assemblyFormat = [{
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($transformed), "false")
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
Expand Down Expand Up @@ -1233,9 +1236,10 @@ def ReplaceOp : Op<Transform_Dialect, "structured.replace",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$replacement);
let regions = (region SizedRegion<1>:$bodyRegion);
let assemblyFormat =
"$target attr-dict-with-keyword regions `:` "
"custom<SemiFunctionType>(type($target), type($replacement))";
let assemblyFormat = [{
$target attr-dict-with-keyword regions `:`
custom<SemiFunctionType>(type($target), type($replacement), "false")
}];
let hasVerifier = 1;
}

Expand Down Expand Up @@ -1274,9 +1278,10 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);

let assemblyFormat =
"$target attr-dict `:`"
"custom<SemiFunctionType>(type($target), type($result))";
let assemblyFormat = [{
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($result), "false")
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/TransformOps/Syntax.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Operation;
/// the argument type in absence of result types, and does not accept the
/// trailing `-> ()` construct, which makes the syntax nicer for operations.
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
Type &resultType);
Type &resultType, bool resultOptional = true);
ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
SmallVectorImpl<Type> &resultTypes);

Expand All @@ -40,7 +40,8 @@ ParseResult parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
Type argumentType, TypeRange resultType);
void printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
Type argumentType, Type resultType);
Type argumentType, Type resultType,
bool resultOptional = true);
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_TRANSFORMOPS_SYNTAX_H
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def MatchSparseInOut : Op<Transform_Dialect, "sparse_tensor.match.sparse_inout",
let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);

let assemblyFormat = "$target attr-dict `:` custom<SemiFunctionType>(type($target), type($result))";
let assemblyFormat = [{
$target attr-dict `:`
custom<SemiFunctionType>(type($target), type($result), "false")
}];
let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
::mlir::Value getOperandHandle() { return getTarget(); }
}];
Expand Down
12 changes: 9 additions & 3 deletions mlir/lib/Dialect/Linalg/TransformOps/Syntax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
using namespace mlir;

ParseResult mlir::parseSemiFunctionType(OpAsmParser &parser, Type &argumentType,
Type &resultType) {
Type &resultType, bool resultOptional) {
argumentType = resultType = nullptr;
bool hasLParen = parser.parseOptionalLParen().succeeded();

bool hasLParen = resultOptional ? parser.parseOptionalLParen().succeeded()
: parser.parseLParen().succeeded();
if (!resultOptional && !hasLParen)
return failure();
if (parser.parseType(argumentType).failed())
return failure();
if (!hasLParen)
Comment on lines +20 to 24
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need the new resultOptional option? Does it work if you do something like:

  argumentType = resultType = nullptr;
  bool hasLParen = parser.parseOptionalLParen().succeeded();
  if (parser.parseType(argumentType).failed())
    return failure();
  if (!hasLParen)
    return success(parser.parseRParen().failed() ||
                   parser.parseArrow().failed() ||
                   parser.parseType(resultType).failed());
  return failure(parser.parseRParen().failed() ||
                 parser.parseArrow().failed() ||
                 parser.parseType(resultType).failed());

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes problems for Ops with exactly one non optional result, like transform.structured.generalize. We could write an invalid version of such an op like this:

transform.structured.generalize %arg0 : !transform.any_op

This would successfully go through your parser because it looks like a valid SemiFunctionType with optional or no results. However, this will then lead to a crash during Op creation, because resultType is never initialized, yet it is not optional.
In general, if we have a non-optional result, we always either have to fail within the parser, or we have to correctly read the result's Type. Without the resultOptional argument enforcing the reading of the result, we can have ops slip through the cracks by using the form without parens and result Types.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. In that case there is no way for the parser to know that the result is required. Considering that these ops must have a result, perhaps it would be better for these transform ops to use some different type of assembly format (that requires a result) instead of SemiFunctionType, but the extra option seems okay to me.

Expand Down Expand Up @@ -69,7 +73,9 @@ void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
}

void mlir::printSemiFunctionType(OpAsmPrinter &printer, Operation *op,
Type argumentType, Type resultType) {
Type argumentType, Type resultType,
bool resultOptional) {
assert(resultOptional || resultType != nullptr);
return printSemiFunctionType(printer, op, argumentType,
resultType ? TypeRange(resultType)
: TypeRange());
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/Linalg/transform-ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,11 @@ transform.sequence failures(propagate) {
transform.structured.vectorize %arg0 vector_sizes [%0 : !transform.param<i64>, 2] : !transform.any_op, !transform.param<i64>

}

// -----

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-error@below {{expected '('}}
%res = transform.structured.generalize %arg0 : !transform.any_op -> !transform.any_op
}
Loading