diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 16bc857ad3416..5b5ec841917e7 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -552,8 +552,11 @@ class SymbolAlias { /// Print this alias to the given stream. void print(raw_ostream &os) const { os << (isType ? "!" : "#") << name; - if (suffixIndex) + if (suffixIndex) { + if (isdigit(name.back())) + os << '_'; os << suffixIndex; + } } /// Returns true if this is a type alias. @@ -659,6 +662,12 @@ class AliasInitializer { template void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred); + /// Uniques the given alias name within the printer by generating name index + /// used as alias name suffix. + static unsigned + uniqueAliasNameIndex(StringRef alias, llvm::StringMap &nameCounts, + llvm::StringSet &usedAliases); + /// Given a collection of aliases and symbols, initialize a mapping from a /// symbol to a given alias. static void initializeAliases( @@ -1025,8 +1034,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { /// the string needs to be modified in any way, the provided buffer is used to /// store the new copy, static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, - StringRef allowedPunctChars = "$._-", - bool allowTrailingDigit = true) { + StringRef allowedPunctChars = "$._-") { assert(!name.empty() && "Shouldn't have an empty name here"); auto validChar = [&](char ch) { @@ -1053,14 +1061,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, return buffer; } - // If the name ends with a trailing digit, add a '_' to avoid potential - // conflicts with autogenerated ID's. - if (!allowTrailingDigit && isdigit(name.back())) { - copyNameToBuffer(); - buffer.push_back('_'); - return buffer; - } - // Check to see that the name consists of only valid identifier characters. 
for (char ch : name) { if (!validChar(ch)) { @@ -1073,6 +1073,36 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer, return name; } +unsigned AliasInitializer::uniqueAliasNameIndex( + StringRef alias, llvm::StringMap &nameCounts, + llvm::StringSet &usedAliases) { + if (!usedAliases.count(alias)) { + usedAliases.insert(alias); + // 0 is not printed in SymbolAlias. + return 0; + } + // Otherwise, we had a conflict - probe until we find a unique name. + SmallString<64> probeAlias(alias); + // An alias with a trailing digit is printed as alias_N, so probe likewise. + if (isdigit(alias.back())) + probeAlias.push_back('_'); + // nameCounts starts from 1 because 0 is not printed in SymbolAlias. + if (nameCounts[probeAlias] == 0) + nameCounts[probeAlias] = 1; + // This is guaranteed to terminate (and usually in a single iteration) + // because it generates new names by incrementing nameCounts. + while (true) { + unsigned nameIndex = nameCounts[probeAlias]++; + probeAlias += llvm::utostr(nameIndex); + if (!usedAliases.count(probeAlias)) { + usedAliases.insert(probeAlias); + return nameIndex; + } + // Drop the tried index; keep the alias and any '_' separator. + probeAlias.resize(alias.size() + (isdigit(alias.back()) ? 1 : 0)); + } +} + /// Given a collection of aliases and symbols, initialize a mapping from a /// symbol to a given alias. void AliasInitializer::initializeAliases( @@ -1084,12 +1114,17 @@ void AliasInitializer::initializeAliases( return lhs.second < rhs.second; }); + // This keeps track of all of the non-numeric names that are in flight, + // allowing us to check for duplicates. 
+ llvm::BumpPtrAllocator usedAliasAllocator; + llvm::StringSet usedAliases(usedAliasAllocator); + llvm::StringMap nameCounts; for (auto &[symbol, aliasInfo] : unprocessedAliases) { if (!aliasInfo.alias) continue; StringRef alias = *aliasInfo.alias; - unsigned nameIndex = nameCounts[alias]++; + unsigned nameIndex = uniqueAliasNameIndex(alias, nameCounts, usedAliases); symbolToAlias.insert( {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType, aliasInfo.canBeDeferred)}); @@ -1196,8 +1231,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias, SmallString<16> tempBuffer; StringRef name = - sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-", - /*allowTrailingDigit=*/false); + sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-"); name = name.copy(aliasAllocator); alias = InProgressAliasInfo(name); } diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir index e878d862076c9..37eff99d3cebf 100644 --- a/mlir/test/IR/print-attr-type-aliases.mlir +++ b/mlir/test/IR/print-attr-type-aliases.mlir @@ -5,7 +5,7 @@ // CHECK-DAG: #test2Ealias = "alias_test:dot_in_name" "test.op"() {alias_test = "alias_test:dot_in_name"} : () -> () -// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit" +// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit" "test.op"() {alias_test = "alias_test:trailing_digit"} : () -> () // CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit" @@ -14,9 +14,15 @@ // CHECK-DAG: #_25test = "alias_test:prefixed_symbol" "test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> () -// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a" -// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b" -"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> () +// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_b" +// CHECK-DAG: #test_alias_conflict0_1 = 
"alias_test:trailing_digit_conflict_c" +// CHECK-DAG: #test_alias_conflict0_ = "alias_test:trailing_digit_conflict_d" +// CHECK-DAG: #test_alias_conflict0_1_1 = "alias_test:trailing_digit_conflict_e" +// CHECK-DAG: #test_alias_conflict0_1_2 = "alias_test:trailing_digit_conflict_f" +// CHECK-DAG: #test_alias_conflict0_1_ = "alias_test:trailing_digit_conflict_g" +// CHECK-DAG: #test_alias_conflict0_1_1_1 = "alias_test:trailing_digit_conflict_h" +// CHECK-DAG: #test_alias_conflict0_1_1_1_1 = "alias_test:trailing_digit_conflict_a" +"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d", "alias_test:trailing_digit_conflict_e", "alias_test:trailing_digit_conflict_f", "alias_test:trailing_digit_conflict_g", "alias_test:trailing_digit_conflict_h"]} : () -> () // CHECK-DAG: !tuple = tuple "test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple) @@ -28,8 +34,8 @@ // CHECK-DAG: tensor<32xf32, #test_encoding> "test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding"> -// CHECK-DAG: !test_ui8_ = !test.int -// CHECK-DAG: tensor<32x!test_ui8_> +// CHECK-DAG: !test_ui8 = !test.int +// CHECK-DAG: tensor<32x!test_ui8> "test.op"() : () -> tensor<32x!test.int> // CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested") @@ -47,8 +53,8 @@ // ----- // Ensure self type parameters get considered for aliases. 
-// CHECK: !test_ui8_ = !test.int -// CHECK: #test.attr_with_self_type_param : !test_ui8_ +// CHECK: !test_ui8 = !test.int +// CHECK: #test.attr_with_self_type_param : !test_ui8 "test.op"() {alias_test = #test.attr_with_self_type_param : !test.int } : () -> () // ----- diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir index 42aecb41d998d..b8111d9601e48 100644 --- a/mlir/test/IR/recursive-type.mlir +++ b/mlir/test/IR/recursive-type.mlir @@ -4,8 +4,8 @@ // CHECK: ![[$NAME:.*]] = !test.test_rec_alias> // CHECK: ![[$NAME5:.*]] = !test.test_rec_alias>>> // CHECK: ![[$NAME2:.*]] = !test.test_rec_alias, i32>> -// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias -// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias +// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias +// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias // CHECK-LABEL: @roundtrip func.func @roundtrip() { diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp index 64add8cef3698..01ae245e06e5a 100644 --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -187,12 +187,24 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { StringSwitch>(strAttr.getValue()) .Case("alias_test:dot_in_name", StringRef("test.alias")) .Case("alias_test:trailing_digit", StringRef("test_alias0")) - .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) - .Case("alias_test:prefixed_symbol", StringRef("%test")) - .Case("alias_test:sanitize_conflict_a", + .Case("alias_test:trailing_digit_conflict_a", + StringRef("test_alias_conflict0_1_1_1")) + .Case("alias_test:trailing_digit_conflict_b", + StringRef("test_alias_conflict0")) + .Case("alias_test:trailing_digit_conflict_c", StringRef("test_alias_conflict0")) - .Case("alias_test:sanitize_conflict_b", + .Case("alias_test:trailing_digit_conflict_d", StringRef("test_alias_conflict0_")) + .Case("alias_test:trailing_digit_conflict_e", 
+ StringRef("test_alias_conflict0_1")) + .Case("alias_test:trailing_digit_conflict_f", + StringRef("test_alias_conflict0_1")) + .Case("alias_test:trailing_digit_conflict_g", + StringRef("test_alias_conflict0_1_")) + .Case("alias_test:trailing_digit_conflict_h", + StringRef("test_alias_conflict0_1_1")) + .Case("alias_test:prefixed_digit", StringRef("0_test_alias")) + .Case("alias_test:prefixed_symbol", StringRef("%test")) .Case("alias_test:tensor_encoding", StringRef("test_encoding")) .Default(std::nullopt); if (!aliasName)