diff --git a/mlir/docs/DefiningDialects/AttributesAndTypes.md b/mlir/docs/DefiningDialects/AttributesAndTypes.md index 44be4c8ed3dcc..0db84a5f0fd8b 100644 --- a/mlir/docs/DefiningDialects/AttributesAndTypes.md +++ b/mlir/docs/DefiningDialects/AttributesAndTypes.md @@ -842,9 +842,9 @@ if they are not present. ###### `struct` Directive -The `struct` directive accepts a list of variables to capture and will generate -a parser and printer for a comma-separated list of key-value pairs. If an -optional parameter is included in the `struct`, it can be elided. The variables +The `struct` directive accepts a list of variables or directives to capture and +will generate a parser and printer for a comma-separated list of key-value pairs. +If an optional parameter is included in the `struct`, it can be elided. The variables are printed in the order they are specified in the argument list **but can be parsed in any order**. For example: @@ -876,6 +876,13 @@ assembly format of `` `<` struct(params) `>` `` will result in: The order in which the parameters are printed is the order in which they are declared in the attribute's or type's `parameter` list. +Passing `custom($variable)` allows providing a custom printer and parser +for the encapsulated variable. Check the +[custom and ref directive](#custom-and-ref-directive) section for more +information about how to define the printer and parser functions. Note that a +custom directive within a struct directive can only encapsulate a single +variable. + ###### `custom` and `ref` directive The `custom` directive is used to dispatch calls to user-defined printer and diff --git a/mlir/test/IR/custom-struct-attr-roundtrip.mlir b/mlir/test/IR/custom-struct-attr-roundtrip.mlir new file mode 100644 index 0000000000000..a8873d01f1df7 --- /dev/null +++ b/mlir/test/IR/custom-struct-attr-roundtrip.mlir @@ -0,0 +1,68 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s + +// CHECK-LABEL: @test_struct_attr_roundtrip +func.func @test_struct_attr_roundtrip() -> () { + // CHECK: attr = #test.custom_struct + "test.op"() {attr = #test.custom_struct} : () -> () + // CHECK: attr = #test.custom_struct + "test.op"() {attr = #test.custom_struct} : () -> () + // CHECK: attr = #test.custom_struct + "test.op"() {attr = #test.custom_struct} : () -> () + // CHECK: attr = #test.custom_struct + "test.op"() {attr = #test.custom_struct} : () -> () + // CHECK: attr = #test.custom_struct + "test.op"() {attr = #test.custom_struct} : () -> () + return +} + +// ----- + +// Verify all required parameters must be provided. `value` is missing. + +// expected-error @below {{struct is missing required parameter: value}} +"test.op"() {attr = #test.custom_struct} : () -> () + +// ----- + +// Verify all keywords must be provided. All missing. + +// expected-error @below {{expected valid keyword}} +// expected-error @below {{expected a parameter name in struct}} +"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> () + +// ----- + +// Verify all keywords must be provided. `type_str` missing. + +// expected-error @below {{expected valid keyword}} +// expected-error @below {{expected a parameter name in struct}} +"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> () + +// ----- + +// Verify all keywords must be provided. `value` missing. + +// expected-error @below {{expected valid keyword}} +// expected-error @below {{expected a parameter name in struct}} +"test.op"() {attr = #test.custom_struct} : () -> () + +// ----- + +// Verify invalid keyword provided. + +// expected-error @below {{duplicate or unknown struct parameter name: type_str2}} +"test.op"() {attr = #test.custom_struct} : () -> () + +// ----- + +// Verify duplicated keyword provided. + +// expected-error @below {{duplicate or unknown struct parameter name: type_str}} +"test.op"() {attr = #test.custom_struct} : () -> () + +// ----- + +// Verify equals missing. + +// expected-error @below {{expected '='}} +"test.op"() {attr = #test.custom_struct} : () -> () diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index fc2d77af29f12..6441a82d87eba 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -369,6 +369,16 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> { }]; } +// Test `struct` with nested `custom` assembly format. +def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> { + let mnemonic = "custom_struct"; + let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value, + OptionalParameter<"mlir::ArrayAttr">:$opt_value); + let assemblyFormat = [{ + `<` struct($type_str, custom($value), custom($opt_value)) `>` + }]; +} + def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> { let mnemonic = "nested_polynomial"; let parameters = (ins Polynomial_IntPolynomialAttr:$poly); diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index 057d9fb4a215f..cd69eaf8014f6 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -316,6 +316,49 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr, return success(); } +//===----------------------------------------------------------------------===// +// TestCustomStructAttr +//===----------------------------------------------------------------------===// + +static void printCustomStructAttr(AsmPrinter &p, int64_t value) { + if (ShapedType::isDynamic(value)) { + p << "?"; + } else { + p.printStrippedAttrOrType(value); + } +} + +static ParseResult parseCustomStructAttr(AsmParser &p, int64_t &value) { + if (succeeded(p.parseOptionalQuestion())) { + value = ShapedType::kDynamic; + return success(); + } + return p.parseInteger(value); +} + +static void printCustomOptStructFieldAttr(AsmPrinter &p, ArrayAttr attr) { + if (attr && attr.size() == 1 && isa(attr[0])) { + p << cast(attr[0]).getInt(); + } else { + p.printStrippedAttrOrType(attr); + } +} + +static ParseResult parseCustomOptStructFieldAttr(AsmParser &p, + ArrayAttr &attr) { + int64_t value; + OptionalParseResult result = p.parseOptionalInteger(value); + if (result.has_value()) { + if (failed(result.value())) + return failure(); + attr = ArrayAttr::get( + p.getContext(), + {IntegerAttr::get(IntegerType::get(p.getContext(), 64), value)}); + return success(); + } + return p.parseAttribute(attr); +} + //===----------------------------------------------------------------------===// // TestOpAsmAttrInterfaceAttr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td index 3a57cbca4d7bb..9a521a5053c0f 100644 --- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td @@ -37,14 +37,14 @@ def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> { def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> { let parameters = (ins "int":$v0); // CHECK: literals may only be used in the top-level section of the format - // CHECK: expected a variable in `struct` argument list + // CHECK: expected a parameter or `custom` directive in `struct` argument list let assemblyFormat = "`<` struct($v0, `,`) `>`"; } // Test struct directive cannot capture zero parameters. def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> { let parameters = (ins "int":$v0); - // CHECK: `struct` argument list expected a variable or directive + // CHECK: `struct` argument list expected a parameter or directive let assemblyFormat = "`<` struct() $v0 `>`"; } @@ -144,3 +144,24 @@ def InvalidTypeT : InvalidType<"InvalidTypeT", "invalid_t"> { // CHECK: `custom` directive with no bound parameters cannot be used as optional group anchor let assemblyFormat = "$a (`(` custom(ref($a))^ `)`)?"; } + +// Test `struct` with nested `custom` directive with multiple fields. +def InvalidTypeU : InvalidType<"InvalidTypeU", "invalid_u"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + // CHECK: `struct` can only contain `custom` directives with a single argument + let assemblyFormat = "struct(custom($a, $b))"; +} + +// Test `struct` with nested `custom` directive invalid parameter. +def InvalidTypeV : InvalidType<"InvalidTypeV", "invalid_v"> { + let parameters = (ins OptionalParameter<"int">:$a); + // CHECK: a `custom` directive nested within a `struct` must be passed a parameter + let assemblyFormat = "struct($a, custom(ref($a)))"; +} + +// Test `custom` with nested `custom` directive invalid parameter. +def InvalidTypeW : InvalidType<"InvalidTypeV", "invalid_v"> { + let parameters = (ins OptionalParameter<"int">:$a, "int":$b); + // CHECK: `custom` can only be used at the top-level context or within a `struct` directive + let assemblyFormat = "custom($a, custom($b))"; +} diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td index c5348409e8e44..0f6b0c401a4e6 100644 --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -736,6 +736,84 @@ def TypeS : TestType<"TestS"> { let assemblyFormat = "$a"; } +/// Test that a `struct` with nested `custom` parser and printer are generated correctly. + +// ATTR: ::mlir::Attribute TestTAttr::parse(::mlir::AsmParser &odsParser, +// ATTR: ::mlir::Type odsType) { +// ATTR: bool _seen_v0 = false; +// ATTR: bool _seen_v1 = false; +// ATTR: bool _seen_v2 = false; +// ATTR: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool { +// ATTR: if (odsParser.parseEqual()) +// ATTR: return {}; +// ATTR: if (!_seen_v0 && _paramKey == "v0") { +// ATTR: _seen_v0 = true; +// ATTR: _result_v0 = ::parseAttrParamA(odsParser, odsType); +// ATTR: if (::mlir::failed(_result_v0)) +// ATTR: return {}; +// ATTR: } else if (!_seen_v1 && _paramKey == "v1") { +// ATTR: _seen_v1 = true; +// ATTR: { +// ATTR: auto odsCustomResult = parseNestedCustom(odsParser, +// ATTR-NEXT: ::mlir::detail::unwrapForCustomParse(_result_v1)); +// ATTR: if (::mlir::failed(odsCustomResult)) return {}; +// ATTR: if (::mlir::failed(_result_v1)) { +// ATTR: odsParser.emitError(odsCustomLoc, "custom parser failed to parse parameter 'v1'"); +// ATTR: return {}; +// ATTR: } +// ATTR: } +// ATTR: } else if (!_seen_v2 && _paramKey == "v2") { +// ATTR: _seen_v2 = true; +// ATTR: _result_v2 = ::mlir::FieldParser::parse(odsParser); +// ATTR: if (::mlir::failed(_result_v2)) { +// ATTR: odsParser.emitError(odsParser.getCurrentLocation(), "failed to parse AttrT parameter 'v2' which is to be a `AttrParamB`"); +// ATTR: return {}; +// ATTR: } +// ATTR: } else { +// ATTR: return {}; +// ATTR: } +// ATTR: return true; +// ATTR: } +// ATTR: do { +// ATTR: ::llvm::StringRef _paramKey; +// ATTR: if (odsParser.parseKeyword(&_paramKey)) { +// ATTR: odsParser.emitError(odsParser.getCurrentLocation(), +// ATTR-NEXT: "expected a parameter name in struct"); +// ATTR: return {}; +// ATTR: } +// ATTR: if (!_loop_body(_paramKey)) return {}; +// ATTR: } while(!odsParser.parseOptionalComma()); +// ATTR: if (!_seen_v0) +// ATTR: if (!_seen_v1) +// ATTR: return TestTAttr::get(odsParser.getContext(), +// ATTR: TestParamA((*_result_v0)), +// ATTR: TestParamB((*_result_v1)), +// ATTR: AttrParamB((_result_v2.value_or(AttrParamB())))); +// ATTR: } + +// ATTR: void TestTAttr::print(::mlir::AsmPrinter &odsPrinter) const { +// ATTR: odsPrinter << "v0 = "; +// ATTR: ::printAttrParamA(odsPrinter, getV0()); +// ATTR: odsPrinter << ", "; +// ATTR: odsPrinter << "v1 = "; +// ATTR: printNestedCustom(odsPrinter, +// ATTR-NEXT: getV1()); +// ATTR: if (!(getV2() == AttrParamB())) { +// ATTR: odsPrinter << "v2 = "; +// ATTR: odsPrinter.printStrippedAttrOrType(getV2()); +// ATTR: } + +def AttrT : TestAttr<"TestT"> { + let parameters = (ins + AttrParamA:$v0, + AttrParamB:$v1, + OptionalParameter<"AttrParamB">:$v2 + ); + + let mnemonic = "attr_t"; + let assemblyFormat = "`{` struct($v0, custom($v1), $v2) `}`"; +} + // DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser) // DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType); // DEFAULT_TYPE_PARSER: if (parseResult.has_value()) { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index a4ae271edb6bd..5c5303e27fde3 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -13,6 +13,7 @@ #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/ADT/BitVector.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" @@ -75,37 +76,35 @@ class ParameterElement AttrOrTypeParameter param; }; +/// Utility to return the encapsulated parameter element for the provided format +/// element. This parameter can originate from either a `ParameterElement`, +/// `CustomDirective` with a single parameter argument or `RefDirective`. +static ParameterElement *getEncapsulatedParameterElement(FormatElement *el) { + return TypeSwitch(el) + .Case([&](auto custom) { + FailureOr maybeParam = + custom->template getFrontAs(); + return *maybeParam; + }) + .Case([&](auto param) { return param; }) + .Case( + [&](auto ref) { return cast(ref->getArg()); }) + .Default([&](auto el) { + assert(false && "unexpected struct element type"); + return nullptr; + }); +} + /// Shorthand functions that can be used with ranged-based conditions. static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); } +static bool formatIsOptional(FormatElement *el) { + ParameterElement *param = getEncapsulatedParameterElement(el); + return param != nullptr && param->isOptional(); +} static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); } - -/// Base class for a directive that contains references to multiple variables. -template -class ParamsDirectiveBase : public DirectiveElementBase { -public: - using Base = ParamsDirectiveBase; - - ParamsDirectiveBase(std::vector &¶ms) - : params(std::move(params)) {} - - /// Get the parameters contained in this directive. - ArrayRef getParams() const { return params; } - - /// Get the number of parameters. - unsigned getNumParams() const { return params.size(); } - - /// Take all of the parameters from this directive. - std::vector takeParams() { return std::move(params); } - - /// Returns true if there are optional parameters present. - bool hasOptionalParams() const { - return llvm::any_of(getParams(), paramIsOptional); - } - -private: - /// The parameters captured by this directive. - std::vector params; -}; +static bool formatNotOptional(FormatElement *el) { + return !formatIsOptional(el); +} /// This class represents a `params` directive that refers to all parameters /// of an attribute or type. When used as a top-level directive, it generates @@ -116,9 +115,15 @@ class ParamsDirectiveBase : public DirectiveElementBase { /// When used as an argument to another directive that accepts variables, /// `params` can be used in place of manually listing all parameters of an /// attribute or type. -class ParamsDirective : public ParamsDirectiveBase { +class ParamsDirective + : public VectorDirectiveBase { public: using Base::Base; + + /// Returns true if there are optional parameters present. + bool hasOptionalElements() const { + return llvm::any_of(getElements(), paramIsOptional); + } }; /// This class represents a `struct` directive that generates a struct format @@ -126,9 +131,15 @@ class ParamsDirective : public ParamsDirectiveBase { /// /// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}` /// -class StructDirective : public ParamsDirectiveBase { +class StructDirective + : public VectorDirectiveBase { public: using Base::Base; + + /// Returns true if there are optional format elements present. + bool hasOptionalElements() const { + return llvm::any_of(getElements(), formatIsOptional); + } }; } // namespace @@ -214,10 +225,10 @@ class DefFormat { /// Generate the printer code for a variable. void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os, bool skipGuard = false); - /// Generate a printer for comma-separated parameters. - void genCommaSeparatedPrinter(ArrayRef params, + /// Generate a printer for comma-separated format elements. + void genCommaSeparatedPrinter(ArrayRef params, FmtContext &ctx, MethodBody &os, - function_ref extra); + function_ref extra); /// Generate the printer code for a `params` directive. void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a `struct` directive. @@ -443,14 +454,14 @@ void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx, // If there are optional parameters, we need to switch to `parseOptionalComma` // if there are no more required parameters after a certain point. - bool hasOptional = el->hasOptionalParams(); + bool hasOptional = el->hasOptionalElements(); if (hasOptional) { // Wrap everything in a do-while so that we can `break`. os << "do {\n"; os.indent(); } - ArrayRef params = el->getParams(); + ArrayRef params = el->getElements(); using IteratorT = ParameterElement *const *; IteratorT it = params.begin(); @@ -551,22 +562,31 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, while (!$_parser.parseOptionalComma()) { )"; + const char *const checkParamKey = R"( + if (!_seen_{0} && _paramKey == "{0}") { + _seen_{0} = true; +)"; + os << "// Parse parameter struct\n"; // Declare a "seen" variable for each key. - for (ParameterElement *param : el->getParams()) + for (FormatElement *arg : el->getElements()) { + ParameterElement *param = getEncapsulatedParameterElement(arg); os << formatv("bool _seen_{0} = false;\n", param->getName()); + } // Generate the body of the parsing loop inside a lambda. os << "{\n"; os.indent() << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n"; genLiteralParser("=", ctx, os.indent()); - for (ParameterElement *param : el->getParams()) { - os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n" - " _seen_{0} = true;\n", - param->getName()); - genVariableParser(param, ctx, os.indent()); + for (FormatElement *arg : el->getElements()) { + ParameterElement *param = getEncapsulatedParameterElement(arg); + os.getStream().printReindented(strfmt(checkParamKey, param->getName())); + if (auto realParam = dyn_cast(arg)) + genVariableParser(param, ctx, os.indent()); + else if (auto custom = dyn_cast(arg)) + genCustomParser(custom, ctx, os.indent()); os.unindent() << "} else "; // Print the check for duplicate or unknown parameter. } @@ -576,10 +596,10 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, // Generate the parsing loop. If optional parameters are present, then the // parse loop is guarded by commas. - unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional); + unsigned numOptional = llvm::count_if(el->getElements(), formatIsOptional); if (numOptional) { // If the struct itself is optional, pull out the first iteration. - if (numOptional == el->getNumParams()) { + if (numOptional == el->getNumElements()) { os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str()); os.indent(); } else { @@ -587,7 +607,7 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, } } else { os.getStream().printReindented( - tgfmt(loopHeader, &ctx, el->getNumParams()).str()); + tgfmt(loopHeader, &ctx, el->getNumElements()).str()); } os.indent(); os.getStream().printReindented(tgfmt(loopStart, &ctx).str()); @@ -597,12 +617,13 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, // all mandatory parameters have been parsed. // The whole struct is optional if all its parameters are optional. if (numOptional) { - if (numOptional == el->getNumParams()) { + if (numOptional == el->getNumElements()) { os << "}\n"; os.unindent() << "}\n"; } else { os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx); - for (ParameterElement *param : el->getParams()) { + for (FormatElement *arg : el->getElements()) { + ParameterElement *param = getEncapsulatedParameterElement(arg); if (param->isOptional()) continue; os.getStream().printReindented( @@ -614,7 +635,8 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, // N flags, successfully exiting the loop means that all parameters have // been seen. `parseOptionalComma` would cause issues with any formats that // use "struct(...) `,`" beacuse structs aren't sounded by braces. - os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams())); + os.getStream().printReindented( + strfmt(loopTerminator, el->getNumElements())); } os.unindent() << "}\n"; } @@ -631,7 +653,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx, os << "(void)odsCustomLoc;\n"; os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName()); os.indent(); - for (FormatElement *arg : el->getArguments()) { + for (FormatElement *arg : el->getElements()) { os << ",\n"; if (auto *param = dyn_cast(arg)) os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName() @@ -648,7 +670,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx, } else { os << "if (::mlir::failed(odsCustomResult)) return {};\n"; } - for (FormatElement *arg : el->getArguments()) { + for (FormatElement *arg : el->getElements()) { if (auto *param = dyn_cast(arg)) { if (param->isOptional()) continue; @@ -689,7 +711,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, guardOn(llvm::ArrayRef(param)); } else if (auto *params = dyn_cast(first)) { genParamsParser(params, ctx, os); - guardOn(params->getParams()); + guardOn(params->getElements()); } else if (auto *custom = dyn_cast(first)) { os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n"; os.indent(); @@ -704,7 +726,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx, } else { auto *strct = cast(first); genStructParser(strct, ctx, os); - guardOn(params->getParams()); + guardOn(params->getElements()); } os.indent(); @@ -816,14 +838,26 @@ static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &¶ms, os.indent(); } +/// Generate code to guard printing on the presence of any optional format +/// elements. +template +static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os, + FormatElemRange &&args, bool inverted = false) { + guardOnAny(ctx, os, + llvm::make_filter_range( + llvm::map_range(args, getEncapsulatedParameterElement), + [](ParameterElement *param) { return param->isOptional(); }), + inverted); +} + void DefFormat::genCommaSeparatedPrinter( - ArrayRef params, FmtContext &ctx, MethodBody &os, - function_ref extra) { + ArrayRef args, FmtContext &ctx, MethodBody &os, + function_ref extra) { // Emit a space if necessary, but only if the struct is present. if (shouldEmitSpace || !lastWasPunctuation) { - bool allOptional = llvm::all_of(params, paramIsOptional); + bool allOptional = llvm::all_of(args, formatIsOptional); if (allOptional) - guardOnAny(ctx, os, params); + guardOnAnyOptional(ctx, os, args); os << tgfmt("$_printer << ' ';\n", &ctx); if (allOptional) os.unindent() << "}\n"; @@ -832,17 +866,21 @@ void DefFormat::genCommaSeparatedPrinter( // The first printed element does not need to emit a comma. os << "{\n"; os.indent() << "bool _firstPrinted = true;\n"; - for (ParameterElement *param : params) { + for (FormatElement *arg : args) { + ParameterElement *param = getEncapsulatedParameterElement(arg); if (param->isOptional()) { param->genPrintGuard(ctx, os << "if (") << ") {\n"; os.indent(); } os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx); os << "_firstPrinted = false;\n"; - extra(param); + extra(arg); shouldEmitSpace = false; lastWasPunctuation = true; - genVariablePrinter(param, ctx, os); + if (auto realParam = dyn_cast(arg)) + genVariablePrinter(realParam, ctx, os); + else if (auto custom = dyn_cast(arg)) + genCustomPrinter(custom, ctx, os); if (param->isOptional()) os.unindent() << "}\n"; } @@ -851,16 +889,19 @@ void DefFormat::genCommaSeparatedPrinter( void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os) { - genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os, - [&](ParameterElement *param) {}); + SmallVector args = llvm::map_to_vector( + el->getElements(), [](ParameterElement *param) -> FormatElement * { + return static_cast(param); + }); + genCommaSeparatedPrinter(args, ctx, os, [&](FormatElement *param) {}); } void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os) { - genCommaSeparatedPrinter( - llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) { - os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName()); - }); + genCommaSeparatedPrinter(el->getElements(), ctx, os, [&](FormatElement *arg) { + ParameterElement *param = getEncapsulatedParameterElement(arg); + os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName()); + }); } void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx, @@ -873,7 +914,7 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx, os << tgfmt("print$0($_printer", &ctx, el->getName()); os.indent(); - for (FormatElement *arg : el->getArguments()) { + for (FormatElement *arg : el->getElements()) { os << ",\n"; if (auto *param = dyn_cast(arg)) { os << param->getParam().getAccessorName() << "()"; @@ -893,19 +934,12 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, if (auto *param = dyn_cast(anchor)) { guardOnAny(ctx, os, llvm::ArrayRef(param), el->isInverted()); } else if (auto *params = dyn_cast(anchor)) { - guardOnAny(ctx, os, params->getParams(), el->isInverted()); + guardOnAny(ctx, os, params->getElements(), el->isInverted()); } else if (auto *strct = dyn_cast(anchor)) { - guardOnAny(ctx, os, strct->getParams(), el->isInverted()); + guardOnAnyOptional(ctx, os, strct->getElements(), el->isInverted()); } else { auto *custom = cast(anchor); - guardOnAny(ctx, os, - llvm::make_filter_range( - llvm::map_range(custom->getArguments(), - [](FormatElement *el) { - return dyn_cast(el); - }), - [](ParameterElement *param) { return !!param; }), - el->isInverted()); + guardOnAnyOptional(ctx, os, custom->getElements(), el->isInverted()); } // Generate the printer for the contained elements. { @@ -960,6 +994,9 @@ class DefFormatParser : public FormatParser { LogicalResult verifyOptionalGroupElements(SMLoc loc, ArrayRef elements, FormatElement *anchor) override; + /// Verify the arguments to a struct directive. + LogicalResult verifyStructArguments(SMLoc loc, + ArrayRef arguments); LogicalResult markQualified(SMLoc loc, FormatElement *element) override; @@ -1010,7 +1047,7 @@ LogicalResult DefFormatParser::verify(SMLoc loc, auto *literalEl = dyn_cast(std::get<1>(it)); if (!structEl || !literalEl) continue; - if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) { + if (literalEl->getSpelling() == "," && structEl->hasOptionalElements()) { return emitError(loc, "`struct` directive with optional parameters " "cannot be followed by a comma literal"); } @@ -1037,17 +1074,17 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, "parameters in an optional group must be optional"); } } else if (auto *params = dyn_cast(el)) { - if (llvm::any_of(params->getParams(), paramNotOptional)) { + if (llvm::any_of(params->getElements(), paramNotOptional)) { return emitError(loc, "`params` directive allowed in optional group " "only if all parameters are optional"); } } else if (auto *strct = dyn_cast(el)) { - if (llvm::any_of(strct->getParams(), paramNotOptional)) { + if (llvm::any_of(strct->getElements(), formatNotOptional)) { return emitError(loc, "`struct` is only allowed in an optional group " "if all captured parameters are optional"); } } else if (auto *custom = dyn_cast(el)) { - for (FormatElement *el : custom->getArguments()) { + for (FormatElement *el : custom->getElements()) { // If the custom argument is a variable, then it must be optional. if (auto *param = dyn_cast(el)) if (!param->isOptional()) @@ -1068,10 +1105,10 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, // arguments is a bound parameter. if (auto *custom = dyn_cast(anchor)) { const auto *bound = - llvm::find_if(custom->getArguments(), [](FormatElement *el) { + llvm::find_if(custom->getElements(), [](FormatElement *el) { return isa(el); }); - if (bound == custom->getArguments().end()) + if (bound == custom->getElements().end()) return emitError(loc, "`custom` directive with no bound parameters " "cannot be used as optional group anchor"); } @@ -1079,6 +1116,28 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc, return success(); } +LogicalResult +DefFormatParser::verifyStructArguments(SMLoc loc, + ArrayRef arguments) { + for (FormatElement *el : arguments) { + if (!isa(el)) { + return emitError(loc, "expected a parameter, custom directive or params " + "directive in `struct` arguments list"); + } + if (auto custom = dyn_cast(el)) { + if (custom->getNumElements() != 1) { + return emitError(loc, "`struct` can only contain `custom` directives " + "with a single argument"); + } + if (failed(custom->getFrontAs())) { + return emitError(loc, "a `custom` directive nested within a `struct` " + "must be passed a parameter"); + } + } + } + return success(); +} + LogicalResult DefFormatParser::markQualified(SMLoc loc, FormatElement *element) { if (!isa(element)) @@ -1172,37 +1231,45 @@ FailureOr DefFormatParser::parseStructDirective(SMLoc loc, return emitError(loc, "`struct` can only be used at the top-level context"); if (failed(parseToken(FormatToken::l_paren, - "expected '(' before `struct` argument list"))) + "expected '(' before `struct` argument list"))) { return failure(); + } // Parse variables captured by `struct`. - std::vector vars; + std::vector vars; // Parse first captured parameter or a `params` directive. FailureOr var = parseElement(StructDirectiveContext); - if (failed(var) || !isa(*var)) { - return emitError(loc, - "`struct` argument list expected a variable or directive"); + if (failed(var) || + !isa(*var)) { + return emitError( + loc, "`struct` argument list expected a parameter or directive"); } - if (isa(*var)) { + if (isa(*var)) { // Parse any other parameters. - vars.push_back(cast(*var)); + vars.push_back(*var); while (peekToken().is(FormatToken::comma)) { consumeToken(); var = parseElement(StructDirectiveContext); - if (failed(var) || !isa(*var)) - return emitError(loc, "expected a variable in `struct` argument list"); - vars.push_back(cast(*var)); + if (failed(var) || !isa(*var)) + return emitError(loc, "expected a parameter or `custom` directive in " + "`struct` argument list"); + vars.push_back(*var); } } else { // `struct(params)` captures all parameters in the attribute or type. - vars = cast(*var)->takeParams(); + ParamsDirective *params = cast(*var); + vars.reserve(params->getNumElements()); + for (ParameterElement *el : params->takeElements()) + vars.push_back(cast(el)); } if (failed(parseToken(FormatToken::r_paren, - "expected ')' at the end of an argument list"))) + "expected ')' at the end of an argument list"))) { + return failure(); + } + if (failed(verifyStructArguments(loc, vars))) return failure(); - return create(std::move(vars)); } diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp index dd9b41bc90aef..4dfdde2146679 100644 --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -400,8 +400,10 @@ FailureOr FormatParser::parseOptionalGroup(Context ctx) { FailureOr FormatParser::parseCustomDirective(SMLoc loc, Context ctx) { - if (ctx != TopLevelContext) - return emitError(loc, "'custom' is only valid as a top-level directive"); + if (ctx != TopLevelContext && ctx != StructDirectiveContext) { + return emitError(loc, "`custom` can only be used at the top-level context " + "or within a `struct` directive"); + } FailureOr nameTok; if (failed(parseToken(FormatToken::less, diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h index 1dc2cb3eaa88a..8e7d49bb37e71 100644 --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -338,29 +338,56 @@ class DirectiveElementBase : public DirectiveElement { } }; +/// Base class for a directive that contains references to elements of type `T` +/// in a vector. +template +class VectorDirectiveBase : public DirectiveElementBase { +public: + using Base = VectorDirectiveBase; + + VectorDirectiveBase(std::vector &&elems) : elems(std::move(elems)) {} + + /// Get the elements contained in this directive. + ArrayRef getElements() const { return elems; } + + /// Get the number of elements. + unsigned getNumElements() const { return elems.size(); } + + /// Take all of the elements from this directive. + std::vector takeElements() { return std::move(elems); } + +protected: + /// The elements captured by this directive. + std::vector elems; +}; + /// This class represents a custom format directive that is implemented by the /// user in C++. The directive accepts a list of arguments that is passed to the /// C++ function. -class CustomDirective : public DirectiveElementBase { +class CustomDirective + : public VectorDirectiveBase { public: + using Base::Base; /// Create a custom directive with a name and list of arguments. CustomDirective(StringRef name, std::vector &&arguments) - : name(name), arguments(std::move(arguments)) {} + : Base(std::move(arguments)), name(name) {} /// Get the custom directive name. StringRef getName() const { return name; } - /// Get the arguments to the custom directive. - ArrayRef getArguments() const { return arguments; } + template + FailureOr getFrontAs() const { + if (getNumElements() != 1) + return failure(); + if (T *elem = dyn_cast(getElements()[0])) + return elem; + return failure(); + } private: /// The name of the custom directive. The name is used to call two C++ /// methods: `parse{name}` and `print{name}` with the given arguments. StringRef name; - /// The arguments with which to call the custom functions. These are either - /// variables (for which the functions are responsible for populating) or - /// references to variables. - std::vector arguments; }; /// This class represents a reference directive. This directive can be used to diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index ca2c1d4a8ad04..a0d947fe8a0df 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -882,7 +882,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op, } } else if (auto *custom = dyn_cast(element)) { - for (FormatElement *paramElement : custom->getArguments()) + for (FormatElement *paramElement : custom->getElements()) genElementParserStorage(paramElement, op, body); } else if (isa(element)) { @@ -1037,7 +1037,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, // * Add a local variable for optional operands and types. This provides a // better API to the user defined parser methods. // * Set the location of operand variables. - for (FormatElement *param : dir->getArguments()) { + for (FormatElement *param : dir->getElements()) { if (auto *operand = dyn_cast(param)) { auto *var = operand->getVar(); body << " " << var->name @@ -1089,7 +1089,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, } body << " auto odsResult = parse" << dir->getName() << "(parser"; - for (FormatElement *param : dir->getArguments()) { + for (FormatElement *param : dir->getElements()) { body << ", "; genCustomParameterParser(param, body); } @@ -1103,7 +1103,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body, } // After parsing, add handling for any of the optional constructs. - for (FormatElement *param : dir->getArguments()) { + for (FormatElement *param : dir->getElements()) { if (auto *attr = dyn_cast(param)) { const NamedAttribute *var = attr->getVar(); if (var->attr.isOptional() || var->attr.hasDefaultValue()) @@ -2215,7 +2215,7 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element, static void genCustomDirectivePrinter(CustomDirective *customDir, const Operator &op, MethodBody &body) { body << " print" << customDir->getName() << "(_odsPrinter, *this"; - for (FormatElement *param : customDir->getArguments()) { + for (FormatElement *param : customDir->getElements()) { body << ", "; genCustomDirectiveParameterPrinter(param, op, body); } @@ -2359,7 +2359,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor, .Case([&](CustomDirective *ele) { body << '('; llvm::interleave( - ele->getArguments(), body, + ele->getElements(), body, [&](FormatElement *child) { body << '('; genOptionalGroupPrinterAnchor(child, op, body); @@ -2375,7 +2375,7 @@ void collect(FormatElement *element, TypeSwitch(element) .Case([&](VariableElement *var) { variables.emplace_back(var); }) .Case([&](CustomDirective *ele) { - for (FormatElement *arg : ele->getArguments()) + for (FormatElement *arg : ele->getElements()) collect(arg, variables); }) .Case([&](OptionalElement *ele) { @@ -3774,7 +3774,7 @@ LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc, return success(); // Verify each child as being valid in an optional group. They are all // potential anchors if the custom directive was marked as one. - for (FormatElement *child : ele->getArguments()) { + for (FormatElement *child : ele->getElements()) { if (isa(child)) continue; if (failed(verifyOptionalGroupElement(loc, child, /*isAnchor=*/true)))