diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 6c1ff4d0e5e6b..b4506d58386ec 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -202,7 +202,8 @@ class AsmPrinter { /// special or non-printable characters in it. virtual void printSymbolName(StringRef symbolRef); - /// Print a handle to the given dialect resource. + /// Print a handle to the given dialect resource. The handle key is quoted and + /// escaped if it has any special or non-printable characters in it. virtual void printResourceHandle(const AsmDialectResourceHandle &resource); /// Print an optional arrow followed by a type list. diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h index d5b72d63813a4..1f8fbfdd93568 100644 --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -248,13 +248,7 @@ class AsmParserImpl : public BaseT { /// Parses a quoted string token if present. ParseResult parseOptionalString(std::string *string) override { - if (!parser.getToken().is(Token::string)) - return failure(); - - if (string) - *string = parser.getToken().getStringValue(); - parser.consumeToken(); - return success(); + return parser.parseOptionalString(string); } /// Parses a Base64 encoded string of bytes. @@ -355,13 +349,7 @@ class AsmParserImpl : public BaseT { /// Parse a keyword, if present, into 'keyword'. ParseResult parseOptionalKeyword(StringRef *keyword) override { - // Check that the current token is a keyword. - if (!parser.isCurrentTokenAKeyword()) - return failure(); - - *keyword = parser.getTokenSpelling(); - parser.consumeToken(); - return success(); + return parser.parseOptionalKeyword(keyword); } /// Parse a keyword if it is one of the 'allowedKeywords'. @@ -387,13 +375,7 @@ class AsmParserImpl : public BaseT { /// Parse an optional keyword or string and set instance into 'result'.` ParseResult parseOptionalKeywordOrString(std::string *result) override { - StringRef keyword; - if (succeeded(parseOptionalKeyword(&keyword))) { - *result = keyword.str(); - return success(); - } - - return parseOptionalString(result); + return parser.parseOptionalKeywordOrString(result); } //===--------------------------------------------------------------------===// @@ -514,7 +496,7 @@ class AsmParserImpl : public BaseT { return parser.emitError() << "dialect '" << dialect->getNamespace() << "' does not expect resource handles"; } - StringRef resourceName; + std::string resourceName; return parser.parseResourceHandle(interface, resourceName); } diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp index e3db248164672..fe6d41932f6f8 100644 --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -271,6 +271,17 @@ ParseResult Parser::parseToken(Token::Kind expectedToken, return emitWrongTokenError(message); } +/// Parses a quoted string token if present. +ParseResult Parser::parseOptionalString(std::string *string) { + if (!getToken().is(Token::string)) + return failure(); + + if (string) + *string = getToken().getStringValue(); + consumeToken(); + return success(); +} + /// Parse an optional integer value from the stream. OptionalParseResult Parser::parseOptionalInteger(APInt &result) { // Parse `false` and `true` keywords as 0 and 1 respectively. @@ -412,15 +423,25 @@ ParseResult Parser::parseOptionalKeyword(StringRef *keyword) { return success(); } +ParseResult Parser::parseOptionalKeywordOrString(std::string *result) { + StringRef keyword; + if (succeeded(parseOptionalKeyword(&keyword))) { + *result = keyword.str(); + return success(); + } + + return parseOptionalString(result); +} + //===----------------------------------------------------------------------===// // Resource Parsing FailureOr Parser::parseResourceHandle(const OpAsmDialectInterface *dialect, - StringRef &name) { + std::string &name) { assert(dialect && "expected valid dialect interface"); SMLoc nameLoc = getToken().getLoc(); - if (failed(parseOptionalKeyword(&name))) + if (failed(parseOptionalKeywordOrString(&name))) return emitError("expected identifier key for 'resource' entry"); auto &resources = getState().symbols.dialectResources; @@ -451,7 +472,7 @@ Parser::parseResourceHandle(Dialect *dialect) { return emitError() << "dialect '" << dialect->getNamespace() << "' does not expect resource handles"; } - StringRef resourceName; + std::string resourceName; return parseResourceHandle(interface, resourceName); } @@ -2530,8 +2551,8 @@ class TopLevelOperationParser : public Parser { /// textual format. class ParsedResourceEntry : public AsmParsedResourceEntry { public: - ParsedResourceEntry(StringRef key, SMLoc keyLoc, Token value, Parser &p) - : key(key), keyLoc(keyLoc), value(value), p(p) {} + ParsedResourceEntry(std::string key, SMLoc keyLoc, Token value, Parser &p) + : key(std::move(key)), keyLoc(keyLoc), value(value), p(p) {} ~ParsedResourceEntry() override = default; StringRef getKey() const final { return key; } @@ -2607,7 +2628,7 @@ class ParsedResourceEntry : public AsmParsedResourceEntry { } private: - StringRef key; + std::string key; SMLoc keyLoc; Token value; Parser &p; @@ -2736,7 +2757,7 @@ ParseResult TopLevelOperationParser::parseDialectResourceFileMetadata() { return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult { // Parse the name of the resource entry. SMLoc keyLoc = getToken().getLoc(); - StringRef key; + std::string key; if (failed(parseResourceHandle(handler, key)) || parseToken(Token::colon, "expected ':'")) return failure(); @@ -2763,8 +2784,8 @@ ParseResult TopLevelOperationParser::parseExternalResourceFileMetadata() { return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult { // Parse the name of the resource entry. SMLoc keyLoc = getToken().getLoc(); - StringRef key; - if (failed(parseOptionalKeyword(&key))) + std::string key; + if (failed(parseOptionalKeywordOrString(&key))) return emitError( "expected identifier key for 'external_resources' entry"); if (parseToken(Token::colon, "expected ':'")) diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h index 37670bd789fec..1b8aa7c9dce6f 100644 --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -146,6 +146,9 @@ class Parser { /// output a diagnostic and return failure. ParseResult parseToken(Token::Kind expectedToken, const Twine &message); + /// Parses a quoted string token if present. + ParseResult parseOptionalString(std::string *string); + /// Parse an optional integer value from the stream. OptionalParseResult parseOptionalInteger(APInt &result); @@ -171,13 +174,16 @@ class Parser { /// Parse a keyword, if present, into 'keyword'. ParseResult parseOptionalKeyword(StringRef *keyword); + /// Parse an optional keyword or string and set instance into 'result'.` + ParseResult parseOptionalKeywordOrString(std::string *result); + //===--------------------------------------------------------------------===// // Resource Parsing //===--------------------------------------------------------------------===// /// Parse a handle to a dialect resource within the assembly format. FailureOr - parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name); + parseResourceHandle(const OpAsmDialectInterface *dialect, std::string &name); FailureOr parseResourceHandle(Dialect *dialect); //===--------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 61b90bc9b0a7b..7c2c7f0875e2a 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2146,13 +2146,6 @@ void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { os << ')'; } -void AsmPrinter::Impl::printResourceHandle( - const AsmDialectResourceHandle &resource) { - auto *interface = cast(resource.getDialect()); - os << interface->getResourceKey(resource); - state.getDialectResources()[resource.getDialect()].insert(resource); -} - /// Returns true if the given dialect symbol data is simple enough to print in /// the pretty form. This is essentially when the symbol takes the form: /// identifier (`<` body `>`)? @@ -2237,6 +2230,13 @@ static void printElidedElementsAttr(raw_ostream &os) { os << R"(dense_resource<__elided__>)"; } +void AsmPrinter::Impl::printResourceHandle( + const AsmDialectResourceHandle &resource) { + auto *interface = cast(resource.getDialect()); + ::printKeywordOrString(interface->getResourceKey(resource), os); + state.getDialectResources()[resource.getDialect()].insert(resource); +} + LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { return state.getAliasState().getAlias(attr, os); } @@ -3331,41 +3331,41 @@ void OperationPrinter::printResourceFileMetadata( auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) { checkAddMetadataDict(); - auto printFormatting = [&]() { - // Emit the top-level resource entry if we haven't yet. - if (!std::exchange(hadResource, true)) { - if (needResourceComma) - os << "," << newLine; - os << " " << dictName << "_resources: {" << newLine; - } - // Emit the parent resource entry if we haven't yet. - if (!std::exchange(hadEntry, true)) { - if (needEntryComma) - os << "," << newLine; - os << " " << name << ": {" << newLine; - } else { - os << "," << newLine; - } - }; - + std::string resourceStr; + auto printResourceStr = [&](raw_ostream &os) { os << resourceStr; }; std::optional charLimit = printerFlags.getLargeResourceStringLimit(); if (charLimit.has_value()) { - std::string resourceStr; llvm::raw_string_ostream ss(resourceStr); valueFn(ss); - // Only print entry if it's string is small enough + // Only print entry if its string is small enough. if (resourceStr.size() > charLimit.value()) return; - printFormatting(); - os << " " << key << ": " << resourceStr; + // Don't recompute resourceStr when valueFn is called below. + valueFn = printResourceStr; + } + + // Emit the top-level resource entry if we haven't yet. + if (!std::exchange(hadResource, true)) { + if (needResourceComma) + os << "," << newLine; + os << " " << dictName << "_resources: {" << newLine; + } + // Emit the parent resource entry if we haven't yet. + if (!std::exchange(hadEntry, true)) { + if (needEntryComma) + os << "," << newLine; + os << " " << name << ": {" << newLine; } else { - printFormatting(); - os << " " << key << ": "; - valueFn(os); + os << "," << newLine; } + os << " "; + ::printKeywordOrString(key, os); + os << ": "; + // Call printResourceStr or original valueFn, depending on charLimit. + valueFn(os); }; ResourceBuilder entryBuilder(printFn); provider.buildResources(op, providerArgs..., entryBuilder); diff --git a/mlir/test/Bytecode/resources.mlir b/mlir/test/Bytecode/resources.mlir index 33ed01d20fa0c..3ef220e890042 100644 --- a/mlir/test/Bytecode/resources.mlir +++ b/mlir/test/Bytecode/resources.mlir @@ -4,21 +4,21 @@ module @TestDialectResources attributes { // CHECK: bytecode.test = dense_resource : tensor<2xui32> // CHECK: bytecode.test2 = dense_resource : tensor<4xf64> - // CHECK: bytecode.test3 = dense_resource : tensor<4xf64> + // CHECK: bytecode.test3 = dense_resource<"resource\09two"> : tensor<4xf64> bytecode.test = dense_resource : tensor<2xui32>, bytecode.test2 = dense_resource : tensor<4xf64>, - bytecode.test3 = dense_resource : tensor<4xf64> + bytecode.test3 = dense_resource<"resource\09two"> : tensor<4xf64> } {} // CHECK: builtin: { // CHECK-NEXT: resource: "0x08000000010000000000000002000000000000000300000000000000" -// CHECK-NEXT: resource_2: "0x08000000010000000000000002000000000000000300000000000000" +// CHECK-NEXT: "resource\09two": "0x08000000010000000000000002000000000000000300000000000000" {-# dialect_resources: { builtin: { resource: "0x08000000010000000000000002000000000000000300000000000000", - resource_2: "0x08000000010000000000000002000000000000000300000000000000" + "resource\09two": "0x08000000010000000000000002000000000000000300000000000000" } } #-} diff --git a/mlir/test/IR/dense-resource-elements-attr.mlir b/mlir/test/IR/dense-resource-elements-attr.mlir index adba97994ff60..44cefc3aa1616 100644 --- a/mlir/test/IR/dense-resource-elements-attr.mlir +++ b/mlir/test/IR/dense-resource-elements-attr.mlir @@ -11,3 +11,18 @@ } } #-} + +// ----- + +// DenseResourceElementsHandle key blob\-"one" is quoted and escaped. +// CHECK: attr = dense_resource<"blob\\-\22one\22"> : tensor<2xi16> +"test.user_op"() {attr = dense_resource<"blob\\-\22one\22"> : tensor<2xi16>} : () -> () + +{-# + dialect_resources: { + builtin: { + // CHECK: "blob\\-\22one\22": "0x0200000001000200" + "blob\\-\22one\22": "0x0200000001000200" + } + } +#-} diff --git a/mlir/test/IR/pretty-resources-print.mlir b/mlir/test/IR/pretty-resources-print.mlir index 625967fcb7603..297c83bbb1389 100644 --- a/mlir/test/IR/pretty-resources-print.mlir +++ b/mlir/test/IR/pretty-resources-print.mlir @@ -12,7 +12,7 @@ // CHECK: {-# // CHECK-NEXT: external_resources: { // CHECK-NEXT: external: { -// CHECK-NEXT: bool: true, +// CHECK-NEXT: "backslash\\tab\09": true, // CHECK-NEXT: string: "\22string\22" // CHECK-NEXT: }, // CHECK-NEXT: other_stuff: { @@ -31,8 +31,8 @@ external_resources: { external: { blob: "0x08000000010000000000000002000000000000000300000000000000", - bool: true, - string: "\"string\"" // with escape characters + "backslash\\tab\09": true, // quoted key with escape characters + string: "\"string\"" // string with escape characters }, other_stuff: { bool: true