diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index ae412c7227f8e..5891cbffc9542 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter { /// If the specified operation has attributes, print out an attribute /// dictionary with their values. elidedAttrs allows the client to ignore /// specific well known attributes, commonly used if the attribute value is - /// printed some other way (like as a fixed operand). + /// printed some other way (like as a fixed operand). If printNamedAttrFn is + /// provided the default printing can be overridden for a named attribute. + /// printNamedAttrFn is passed a NamedAttribute, if it prints the attribute + /// it returns `success()`, otherwise, it returns `failure()` which indicates + /// that generic printing should be used. virtual void printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs = {}) = 0; + ArrayRef elidedAttrs = {}, + function_ref + printNamedAttrFn = nullptr) = 0; /// If the specified operation has attributes, print out an attribute /// dictionary prefixed with 'attributes'. @@ -1116,8 +1122,17 @@ class AsmParser { return parseResult; } - /// Parse a named dictionary into 'result' if it is present. - virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; + /// Parse a named dictionary into 'result' if it is present. If + /// parseNamedAttrFn is provided the default parsing can be overridden for a + /// named attribute. parseNamedAttrFn is passed the name of an attribute, if + /// it can parse the attribute it returns the parsed attribute, otherwise, it + /// returns `failure()` which indicates that generic parsing should be used. + /// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser + /// error. + virtual ParseResult parseOptionalAttrDict( + NamedAttrList &result, + function_ref(StringRef)> parseNamedAttrFn = + nullptr) = 0; /// Parse a named dictionary into 'result' if the `attributes` keyword is /// present. diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index b12687833e3fd..808b2ca282f64 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT { } /// Parse a named dictionary into 'result' if it is present. - ParseResult parseOptionalAttrDict(NamedAttrList &result) override { + ParseResult parseOptionalAttrDict( + NamedAttrList &result, + function_ref(StringRef)> parseNamedAttrFn = + nullptr) override { if (parser.getToken().isNot(Token::l_brace)) return success(); - return parser.parseAttributeDict(result); + return parser.parseAttributeDict(result, parseNamedAttrFn); } /// Parse a named dictionary into 'result' if the `attributes` keyword is diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index efa65e49abc33..b687d822e7cb7 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result, /// | `{` attribute-entry (`,` attribute-entry)* `}` /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value /// -ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { +ParseResult Parser::parseAttributeDict( + NamedAttrList &attributes, + function_ref(StringRef)> parseNamedAttrFn) { llvm::SmallDenseSet seenKeys; auto parseElt = [&]() -> ParseResult { // The name of an attribute can either be a bare identifier, or a string. @@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { return success(); } - auto attr = parseAttribute(); + Attribute attr = nullptr; + FailureOr customParsedAttribute; + // Try to parse with `printNamedAttrFn` callback. + if (parseNamedAttrFn && + succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) { + attr = *customParsedAttribute; + } else { + // Otherwise, use generic attribute parser. + attr = parseAttribute(); + } + if (!attr) return failure(); attributes.push_back({*nameId, attr}); diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index 4caab499e1a0e..d5d90f391fd39 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -256,7 +256,9 @@ class Parser { } /// Parse an attribute dictionary. - ParseResult parseAttributeDict(NamedAttrList &attributes); + ParseResult parseAttributeDict( + NamedAttrList &attributes, + function_ref(StringRef)> parseNamedAttrFn = nullptr); /// Parse a distinct attribute. Attribute parseDistinctAttr(Type type); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 02acc8c3f4659..d931fb62cd7bd 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -452,10 +452,13 @@ class AsmPrinter::Impl { void printDimensionList(ArrayRef shape); protected: - void printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs = {}, - bool withKeyword = false); - void printNamedAttribute(NamedAttribute attr); + void printOptionalAttrDict( + ArrayRef attrs, ArrayRef elidedAttrs = {}, + bool withKeyword = false, + function_ref printNamedAttrFn = nullptr); + void printNamedAttribute( + NamedAttribute attr, + function_ref printNamedAttrFn = nullptr); void printTrailingLocation(Location loc, bool allowAlias = true); void printLocationInternal(LocationAttr loc, bool pretty = false, bool isTopLevel = false); @@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter { /// Print the given set of attributes with names not included within /// 'elidedAttrs'. void printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs = {}) override { - if (attrs.empty()) - return; + ArrayRef elidedAttrs = {}, + function_ref + printNamedAttrFn = nullptr) override { + (void)printNamedAttrFn; if (elidedAttrs.empty()) { for (const NamedAttribute &attr : attrs) printAttribute(attr.getValue()); @@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) { .Default([&](Type type) { return printDialectType(type); }); } -void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs, - bool withKeyword) { +void AsmPrinter::Impl::printOptionalAttrDict( + ArrayRef attrs, ArrayRef elidedAttrs, + bool withKeyword, + function_ref printNamedAttrFn) { // If there are no attributes, then there is nothing to be done. if (attrs.empty()) return; @@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef attrs, // Otherwise, print them all out in braces. os << " {"; - interleaveComma(filteredAttrs, - [&](NamedAttribute attr) { printNamedAttribute(attr); }); + interleaveComma(filteredAttrs, [&](NamedAttribute attr) { + printNamedAttribute(attr, printNamedAttrFn); + }); os << '}'; }; @@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef attrs, if (!filteredAttrs.empty()) printFilteredAttributesFn(filteredAttrs); } -void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { +void AsmPrinter::Impl::printNamedAttribute( + NamedAttribute attr, + function_ref printNamedAttrFn) { // Print the name without quotes if possible. ::printKeywordOrString(attr.getName().strref(), os); @@ -2729,6 +2737,11 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { return; os << " = "; + if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) { + /// If we print via the `printNamedAttrFn` callback, skip the generic + /// attribute printing (i.e. the call to `printAttribute`). + return; + } printAttribute(attr.getValue()); } @@ -3149,8 +3162,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { /// Print an optional attribute dictionary with a given set of elided values. void printOptionalAttrDict(ArrayRef attrs, - ArrayRef elidedAttrs = {}) override { - Impl::printOptionalAttrDict(attrs, elidedAttrs); + ArrayRef elidedAttrs = {}, + function_ref + printNamedAttrFn = nullptr) override { + Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false, + printNamedAttrFn); } void printOptionalAttrDictWithKeyword( ArrayRef attrs, diff --git a/mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir b/mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir new file mode 100644 index 0000000000000..5c62430edde22 --- /dev/null +++ b/mlir/test/IR/custom-attr-syntax-in-attr-dict.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt %s | FileCheck %s --check-prefix=CHECK-ROUNDTRIP +// RUN: mlir-opt %s -mlir-print-op-generic | FileCheck %s --check-prefix=CHECK-GENERIC-SYNTAX + +/// This file tetss that "custom_dense_array" (which is a DenseArrayAttribute +/// stored within the attr-dict) is parsed and printed with the "pretty" array +/// syntax (i.e. `[1, 2, 3, 4]`), rather than with the generic dense array +/// syntax (`array`). +/// +/// This is done by injecting custom parsing and printing callbacks into +/// parseOptionalAttrDict() and printOptionalAttrDict(). + +func.func @custom_attr_dict_syntax() { + // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4]} + // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array}> : () -> () + test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4]} + + // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {another_attr = "foo", custom_dense_array = [1, 2, 3, 4]} + // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array}> {another_attr = "foo"} : () -> () + test.custom_attr_parse_and_print_in_attr_dict {another_attr = "foo", custom_dense_array = [1, 2, 3, 4]} + + // CHECK-ROUNDTRIP: test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4], default_array = [1, 2, 3, 4]} + // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array}> {default_array = [1, 2, 3, 4]} : () -> () + test.custom_attr_parse_and_print_in_attr_dict {custom_dense_array = [1, 2, 3, 4], default_array = [1, 2, 3, 4]} + + // CHECK-ROUND-TRIP: test.custom_attr_parse_and_print_in_attr_dict {default_dense_array = array, custom_dense_array = [1, 2, 3, 4]} + // CHECK-GENERIC-SYNTAX: "test.custom_attr_parse_and_print_in_attr_dict"() <{custom_dense_array = array}> {default_dense_array = array} : () -> () + test.custom_attr_parse_and_print_in_attr_dict {default_dense_array = array, custom_dense_array = [1, 2, 3, 4]} + + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index fbaa102d3e33c..616db9cda2ad8 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -141,6 +141,37 @@ void AffineScopeOp::print(OpAsmPrinter &p) { p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); } +//===----------------------------------------------------------------------===// +// CustomAttrParseAndPrintInAttrDict +//===----------------------------------------------------------------------===// + +ParseResult CustomAttrParseAndPrintInAttrDict::parse(OpAsmParser &parser, + OperationState &result) { + return parser.parseOptionalAttrDict( + result.attributes, [&](StringRef name) -> FailureOr { + // Override the parsing for the "custom_dense_array" attribute in the + // attr-dict. Rather than parsing it as array, parse + // it as [0, 1, 2, ...] (i.e. using the standard array syntax). + if (name != getCustomDenseArrayAttrName(result.name)) + return failure(); + return DenseI64ArrayAttr::parse(parser, {}); + }); +} + +void CustomAttrParseAndPrintInAttrDict::print(OpAsmPrinter &p) { + p.printOptionalAttrDict( + (*this)->getAttrs(), {}, + [&](NamedAttribute attrDictNamedAttribute) -> LogicalResult { + // Override the printing for the "custom_dense_array" attribute. Rather + // than printing it as array, print it as + // [0, 1, 2, ...] (i.e. using standard array syntax). + if (attrDictNamedAttribute.getName() != getCustomDenseArrayAttrName()) + return failure(); + cast(attrDictNamedAttribute.getValue()).print(p); + return success(); + }); +} + //===----------------------------------------------------------------------===// // TestRemoveOpWithInnerOps //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 2b55bff3538d3..1c8cdc3fef74e 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2083,6 +2083,16 @@ def OptionalCustomAttrOp : TEST_Op<"optional_custom_attr"> { }]; } +//===----------------------------------------------------------------------===// +// Test overriding attribute parsing/printing in the attr-dict via callbacks +// on parseOptionalAttrDict() and printOptionalAttrDict(). + +def CustomAttrParseAndPrintInAttrDict : TEST_Op<"custom_attr_parse_and_print_in_attr_dict"> +{ + let arguments = (ins DenseI64ArrayAttr:$custom_dense_array); + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Test OpAsmInterface.