Skip to content
3 changes: 2 additions & 1 deletion mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ class AsmPrinter {
/// special or non-printable characters in it.
virtual void printSymbolName(StringRef symbolRef);

/// Print a handle to the given dialect resource.
/// Print a handle to the given dialect resource. The handle key is quoted and
/// escaped if it has any special or non-printable characters in it.
virtual void printResourceHandle(const AsmDialectResourceHandle &resource);

/// Print an optional arrow followed by a type list.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/AsmParser/AsmParserImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ class AsmParserImpl : public BaseT {
return parser.emitError() << "dialect '" << dialect->getNamespace()
<< "' does not expect resource handles";
}
StringRef resourceName;
std::string resourceName;
return parser.parseResourceHandle(interface, resourceName);
}

Expand Down
35 changes: 26 additions & 9 deletions mlir/lib/AsmParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,15 +412,32 @@ ParseResult Parser::parseOptionalKeyword(StringRef *keyword) {
return success();
}

ParseResult Parser::parseOptionalKeywordOrString(std::string *result) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the version of this in AsmParserImpl to call into this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

I did the same for the related parseOptionalKeyword and parserOptionalString methods

StringRef keyword;
if (succeeded(parseOptionalKeyword(&keyword))) {
*result = keyword.str();
return success();
}

// Parse a quoted string token if present.
if (!getToken().is(Token::string))
return failure();

if (result)
*result = getToken().getStringValue();
consumeToken();
return success();
}

//===----------------------------------------------------------------------===//
// Resource Parsing

FailureOr<AsmDialectResourceHandle>
Parser::parseResourceHandle(const OpAsmDialectInterface *dialect,
StringRef &name) {
std::string &name) {
assert(dialect && "expected valid dialect interface");
SMLoc nameLoc = getToken().getLoc();
if (failed(parseOptionalKeyword(&name)))
if (failed(parseOptionalKeywordOrString(&name)))
return emitError("expected identifier key for 'resource' entry");
auto &resources = getState().symbols.dialectResources;

Expand Down Expand Up @@ -451,7 +468,7 @@ Parser::parseResourceHandle(Dialect *dialect) {
return emitError() << "dialect '" << dialect->getNamespace()
<< "' does not expect resource handles";
}
StringRef resourceName;
std::string resourceName;
return parseResourceHandle(interface, resourceName);
}

Expand Down Expand Up @@ -2530,8 +2547,8 @@ class TopLevelOperationParser : public Parser {
/// textual format.
class ParsedResourceEntry : public AsmParsedResourceEntry {
public:
ParsedResourceEntry(StringRef key, SMLoc keyLoc, Token value, Parser &p)
: key(key), keyLoc(keyLoc), value(value), p(p) {}
ParsedResourceEntry(std::string key, SMLoc keyLoc, Token value, Parser &p)
: key(std::move(key)), keyLoc(keyLoc), value(value), p(p) {}
~ParsedResourceEntry() override = default;

StringRef getKey() const final { return key; }
Expand Down Expand Up @@ -2607,7 +2624,7 @@ class ParsedResourceEntry : public AsmParsedResourceEntry {
}

private:
StringRef key;
std::string key;
SMLoc keyLoc;
Token value;
Parser &p;
Expand Down Expand Up @@ -2736,7 +2753,7 @@ ParseResult TopLevelOperationParser::parseDialectResourceFileMetadata() {
return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult {
// Parse the name of the resource entry.
SMLoc keyLoc = getToken().getLoc();
StringRef key;
std::string key;
if (failed(parseResourceHandle(handler, key)) ||
parseToken(Token::colon, "expected ':'"))
return failure();
Expand All @@ -2763,8 +2780,8 @@ ParseResult TopLevelOperationParser::parseExternalResourceFileMetadata() {
return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult {
// Parse the name of the resource entry.
SMLoc keyLoc = getToken().getLoc();
StringRef key;
if (failed(parseOptionalKeyword(&key)))
std::string key;
if (failed(parseOptionalKeywordOrString(&key)))
return emitError(
"expected identifier key for 'external_resources' entry");
if (parseToken(Token::colon, "expected ':'"))
Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/AsmParser/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,16 @@ class Parser {
/// Parse a keyword, if present, into 'keyword'.
ParseResult parseOptionalKeyword(StringRef *keyword);

/// Parse an optional keyword or string and set instance into 'result'.`
ParseResult parseOptionalKeywordOrString(std::string *result);

//===--------------------------------------------------------------------===//
// Resource Parsing
//===--------------------------------------------------------------------===//

/// Parse a handle to a dialect resource within the assembly format.
FailureOr<AsmDialectResourceHandle>
parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name);
parseResourceHandle(const OpAsmDialectInterface *dialect, std::string &name);
FailureOr<AsmDialectResourceHandle> parseResourceHandle(Dialect *dialect);

//===--------------------------------------------------------------------===//
Expand Down
62 changes: 31 additions & 31 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2146,13 +2146,6 @@ void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
os << ')';
}

void AsmPrinter::Impl::printResourceHandle(
const AsmDialectResourceHandle &resource) {
auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
os << interface->getResourceKey(resource);
state.getDialectResources()[resource.getDialect()].insert(resource);
}

/// Returns true if the given dialect symbol data is simple enough to print in
/// the pretty form. This is essentially when the symbol takes the form:
/// identifier (`<` body `>`)?
Expand Down Expand Up @@ -2237,6 +2230,13 @@ static void printElidedElementsAttr(raw_ostream &os) {
os << R"(dense_resource<__elided__>)";
}

void AsmPrinter::Impl::printResourceHandle(
const AsmDialectResourceHandle &resource) {
auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
::printKeywordOrString(interface->getResourceKey(resource), os);
state.getDialectResources()[resource.getDialect()].insert(resource);
}

LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
return state.getAliasState().getAlias(attr, os);
}
Expand Down Expand Up @@ -3331,41 +3331,41 @@ void OperationPrinter::printResourceFileMetadata(
auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
checkAddMetadataDict();

auto printFormatting = [&]() {
// Emit the top-level resource entry if we haven't yet.
if (!std::exchange(hadResource, true)) {
if (needResourceComma)
os << "," << newLine;
os << " " << dictName << "_resources: {" << newLine;
}
// Emit the parent resource entry if we haven't yet.
if (!std::exchange(hadEntry, true)) {
if (needEntryComma)
os << "," << newLine;
os << " " << name << ": {" << newLine;
} else {
os << "," << newLine;
}
};

std::string resourceStr;
auto printResourceStr = [&](raw_ostream &os) { os << resourceStr; };
std::optional<uint64_t> charLimit =
printerFlags.getLargeResourceStringLimit();
if (charLimit.has_value()) {
std::string resourceStr;
llvm::raw_string_ostream ss(resourceStr);
valueFn(ss);

// Only print entry if it's string is small enough
// Only print entry if its string is small enough.
if (resourceStr.size() > charLimit.value())
return;

printFormatting();
os << " " << key << ": " << resourceStr;
// Don't recompute resourceStr when valueFn is called below.
valueFn = printResourceStr;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the old implementation of printFn used the printFormatting closure to reduce code duplication across the two charLimit conditional branches, but there was still some code duplication to print the key which became worse with the new logic

therefore I refactored the implementation to instead “reprogram” valueFn in the case where resourceStr is materialized to check charLimit

}

// Emit the top-level resource entry if we haven't yet.
if (!std::exchange(hadResource, true)) {
if (needResourceComma)
os << "," << newLine;
os << " " << dictName << "_resources: {" << newLine;
}
// Emit the parent resource entry if we haven't yet.
if (!std::exchange(hadEntry, true)) {
if (needEntryComma)
os << "," << newLine;
os << " " << name << ": {" << newLine;
} else {
printFormatting();
os << " " << key << ": ";
valueFn(os);
os << "," << newLine;
}
os << " ";
::printKeywordOrString(key, os);
os << ": ";
// Call printResourceStr or original valueFn, depending on charLimit.
valueFn(os);
};
ResourceBuilder entryBuilder(printFn);
provider.buildResources(op, providerArgs..., entryBuilder);
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Bytecode/resources.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
module @TestDialectResources attributes {
// CHECK: bytecode.test = dense_resource<decl_resource> : tensor<2xui32>
// CHECK: bytecode.test2 = dense_resource<resource> : tensor<4xf64>
// CHECK: bytecode.test3 = dense_resource<resource_2> : tensor<4xf64>
// CHECK: bytecode.test3 = dense_resource<"resource\09two"> : tensor<4xf64>
bytecode.test = dense_resource<decl_resource> : tensor<2xui32>,
bytecode.test2 = dense_resource<resource> : tensor<4xf64>,
bytecode.test3 = dense_resource<resource_2> : tensor<4xf64>
bytecode.test3 = dense_resource<"resource\09two"> : tensor<4xf64>
} {}

// CHECK: builtin: {
// CHECK-NEXT: resource: "0x08000000010000000000000002000000000000000300000000000000"
// CHECK-NEXT: resource_2: "0x08000000010000000000000002000000000000000300000000000000"
// CHECK-NEXT: "resource\09two": "0x08000000010000000000000002000000000000000300000000000000"

{-#
dialect_resources: {
builtin: {
resource: "0x08000000010000000000000002000000000000000300000000000000",
resource_2: "0x08000000010000000000000002000000000000000300000000000000"
"resource\09two": "0x08000000010000000000000002000000000000000300000000000000"
}
}
#-}
15 changes: 15 additions & 0 deletions mlir/test/IR/dense-resource-elements-attr.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,18 @@
}
}
#-}

// -----

// DenseResourceElementsHandle key blob\-"one" is quoted and escaped.
// CHECK: attr = dense_resource<"blob\\-\22one\22"> : tensor<2xi16>
"test.user_op"() {attr = dense_resource<"blob\\-\22one\22"> : tensor<2xi16>} : () -> ()

{-#
dialect_resources: {
builtin: {
// CHECK: "blob\\-\22one\22": "0x0200000001000200"
"blob\\-\22one\22": "0x0200000001000200"
}
}
#-}
6 changes: 3 additions & 3 deletions mlir/test/IR/pretty-resources-print.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// CHECK: {-#
// CHECK-NEXT: external_resources: {
// CHECK-NEXT: external: {
// CHECK-NEXT: bool: true,
// CHECK-NEXT: "backslash\\tab\09": true,
// CHECK-NEXT: string: "\22string\22"
// CHECK-NEXT: },
// CHECK-NEXT: other_stuff: {
Expand All @@ -31,8 +31,8 @@
external_resources: {
external: {
blob: "0x08000000010000000000000002000000000000000300000000000000",
bool: true,
string: "\"string\"" // with escape characters
"backslash\\tab\09": true, // quoted key with escape characters
string: "\"string\"" // string with escape characters
},
other_stuff: {
bool: true
Expand Down
Loading