Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,15 @@ class OpAsmPrinter : public AsmPrinter {
/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
/// specific well known attributes, commonly used if the attribute value is
/// printed some other way (like as a fixed operand).
/// printed some other way (like as a fixed operand). If printNamedAttrFn is
/// provided the default printing can be overridden for a named attribute.
/// printNamedAttrFn is passed a NamedAttribute, if it prints the attribute
/// it returns `success()`, otherwise, it returns `failure()` which indicates
/// that generic printing should be used.
virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) = 0;
ArrayRef<StringRef> elidedAttrs = {},
function_ref<LogicalResult(NamedAttribute)>
printNamedAttrFn = nullptr) = 0;

/// If the specified operation has attributes, print out an attribute
/// dictionary prefixed with 'attributes'.
Expand Down Expand Up @@ -1116,8 +1122,17 @@ class AsmParser {
return parseResult;
}

/// Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
/// Parse a named dictionary into 'result' if it is present. If
/// parseNamedAttrFn is provided the default parsing can be overridden for a
/// named attribute. parseNamedAttrFn is passed the name of an attribute, if
/// it can parse the attribute it returns the parsed attribute, otherwise, it
/// returns `failure()` which indicates that generic parsing should be used.
/// Note: Returning a null Attribute from parseNamedAttrFn indicates a parser
/// error.
virtual ParseResult parseOptionalAttrDict(
NamedAttrList &result,
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
nullptr) = 0;

/// Parse a named dictionary into 'result' if the `attributes` keyword is
/// present.
Expand Down
7 changes: 5 additions & 2 deletions mlir/lib/AsmParser/AsmParserImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,13 @@ class AsmParserImpl : public BaseT {
}

/// Parse a named dictionary into 'result' if it is present.
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
ParseResult parseOptionalAttrDict(
NamedAttrList &result,
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn =
nullptr) override {
if (parser.getToken().isNot(Token::l_brace))
return success();
return parser.parseAttributeDict(result);
return parser.parseAttributeDict(result, parseNamedAttrFn);
}

/// Parse a named dictionary into 'result' if the `attributes` keyword is
Expand Down
16 changes: 14 additions & 2 deletions mlir/lib/AsmParser/AttributeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ OptionalParseResult Parser::parseOptionalAttribute(SymbolRefAttr &result,
/// | `{` attribute-entry (`,` attribute-entry)* `}`
/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
///
ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
ParseResult Parser::parseAttributeDict(
NamedAttrList &attributes,
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn) {
llvm::SmallDenseSet<StringAttr> seenKeys;
auto parseElt = [&]() -> ParseResult {
// The name of an attribute can either be a bare identifier, or a string.
Expand Down Expand Up @@ -329,7 +331,17 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return success();
}

auto attr = parseAttribute();
Attribute attr = nullptr;
FailureOr<Attribute> customParsedAttribute;
// Try to parse with `printNamedAttrFn` callback.
if (parseNamedAttrFn &&
succeeded(customParsedAttribute = parseNamedAttrFn(*nameId))) {
attr = *customParsedAttribute;
} else {
// Otherwise, use generic attribute parser.
attr = parseAttribute();
}

if (!attr)
return failure();
attributes.push_back({*nameId, attr});
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/AsmParser/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ class Parser {
}

/// Parse an attribute dictionary.
ParseResult parseAttributeDict(NamedAttrList &attributes);
ParseResult parseAttributeDict(
NamedAttrList &attributes,
function_ref<FailureOr<Attribute>(StringRef)> parseNamedAttrFn = nullptr);

/// Parse a distinct attribute.
Attribute parseDistinctAttr(Type type);
Expand Down
45 changes: 30 additions & 15 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,13 @@ class AsmPrinter::Impl {
void printDimensionList(ArrayRef<int64_t> shape);

protected:
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {},
bool withKeyword = false);
void printNamedAttribute(NamedAttribute attr);
void printOptionalAttrDict(
ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs = {},
bool withKeyword = false,
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
void printNamedAttribute(
NamedAttribute attr,
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn = nullptr);
void printTrailingLocation(Location loc, bool allowAlias = true);
void printLocationInternal(LocationAttr loc, bool pretty = false,
bool isTopLevel = false);
Expand Down Expand Up @@ -780,9 +783,10 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
/// Print the given set of attributes with names not included within
/// 'elidedAttrs'.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
if (attrs.empty())
return;
ArrayRef<StringRef> elidedAttrs = {},
function_ref<LogicalResult(NamedAttribute)>
printNamedAttrFn = nullptr) override {
(void)printNamedAttrFn;
if (elidedAttrs.empty()) {
for (const NamedAttribute &attr : attrs)
printAttribute(attr.getValue());
Expand Down Expand Up @@ -2687,9 +2691,10 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
.Default([&](Type type) { return printDialectType(type); });
}

void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs,
bool withKeyword) {
void AsmPrinter::Impl::printOptionalAttrDict(
ArrayRef<NamedAttribute> attrs, ArrayRef<StringRef> elidedAttrs,
bool withKeyword,
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
Expand All @@ -2702,8 +2707,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,

// Otherwise, print them all out in braces.
os << " {";
interleaveComma(filteredAttrs,
[&](NamedAttribute attr) { printNamedAttribute(attr); });
interleaveComma(filteredAttrs, [&](NamedAttribute attr) {
printNamedAttribute(attr, printNamedAttrFn);
});
os << '}';
};

Expand All @@ -2720,7 +2726,9 @@ void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
if (!filteredAttrs.empty())
printFilteredAttributesFn(filteredAttrs);
}
void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
void AsmPrinter::Impl::printNamedAttribute(
NamedAttribute attr,
function_ref<LogicalResult(NamedAttribute)> printNamedAttrFn) {
// Print the name without quotes if possible.
::printKeywordOrString(attr.getName().strref(), os);

Expand All @@ -2729,6 +2737,10 @@ void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
return;

os << " = ";
if (printNamedAttrFn && succeeded(printNamedAttrFn(attr))) {
/// If we print via the `printNamedAttrFn` callback skip printing.
return;
}
printAttribute(attr.getValue());
}

Expand Down Expand Up @@ -3149,8 +3161,11 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {

/// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
Impl::printOptionalAttrDict(attrs, elidedAttrs);
ArrayRef<StringRef> elidedAttrs = {},
function_ref<LogicalResult(NamedAttribute)>
printNamedAttrFn = nullptr) override {
Impl::printOptionalAttrDict(attrs, elidedAttrs, /*withKeyword=*/false,
printNamedAttrFn);
}
void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs,
Expand Down