[mlir] Allow trailing digit for alias in AsmPrinter #127993
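@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir
Author: Hongren Zheng (ZenithalHourlyRate)
Changes
When generating aliases from OpAsm{Dialect,Type,Attr}Interface, the result is sanitized, and if the alias provided by the interface has a trailing digit, AsmPrinter attaches an underscore to it, presumably to prevent conflicts.
Motivation
There are two reasons to motivate the change from the old behavior to the proposed behavior.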
func.func @add(%ct: !ct_L0_) -> !ct_L0_ {
  %ct_0 = bgv.add %ct, %ct : (!ct_L0_, !ct_L0_) -> !ct_L0_
  %ct_1 = bgv.add %ct_0, %ct_0 : (!ct_L0_, !ct_L0_) -> !ct_L0_
  %ct_2 = bgv.add %ct_1, %ct_1 : (!ct_L0_, !ct_L0_) -> !ct_L0_
  return %ct_2 : !ct_L0_
}
Which aesthetically would be better if we have (!ct_L0, !ct_L0) -> !ct_L0.
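Conflict detection
!test.type<a = 3> // suggest !name0
!test.type<a = 4> // suggest !name0
!test.another<b = 3> // suggest !name0_
!test.another<b = 4> // suggest !name0_
The conflict detection is based on nameCounts in initializeAliases. In the original way, the first two will get sanitized to !name0_ and initializeAlias can assign unique id 0, 1, 2, 3 to them. In the current way, the alias is counted in nameCounts as if it had a trailing underscore whenever it ends with a digit, so each alias still gets a unique nameIndex.
Full diff: https://github.com/llvm/llvm-project/pull/127993.diff
4 Files Affected: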
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 1f22d4f37a813..8044a1c8507e8 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -547,8 +547,11 @@ class SymbolAlias {
/// Print this alias to the given stream.
void print(raw_ostream &os) const {
os << (isType ? "!" : "#") << name;
- if (suffixIndex)
+ if (suffixIndex) {
+ if (isdigit(name.back()))
+ os << '_';
os << suffixIndex;
+ }
}
/// Returns true if this is a type alias.
@@ -1020,8 +1023,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
/// the string needs to be modified in any way, the provided buffer is used to
/// store the new copy,
static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
- StringRef allowedPunctChars = "$._-",
- bool allowTrailingDigit = true) {
+ StringRef allowedPunctChars = "$._-") {
assert(!name.empty() && "Shouldn't have an empty name here");
auto validChar = [&](char ch) {
@@ -1048,14 +1050,6 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
return buffer;
}
- // If the name ends with a trailing digit, add a '_' to avoid potential
- // conflicts with autogenerated ID's.
- if (!allowTrailingDigit && isdigit(name.back())) {
- copyNameToBuffer();
- buffer.push_back('_');
- return buffer;
- }
-
// Check to see that the name consists of only valid identifier characters.
for (char ch : name) {
if (!validChar(ch)) {
@@ -1084,7 +1078,16 @@ void AliasInitializer::initializeAliases(
if (!aliasInfo.alias)
continue;
StringRef alias = *aliasInfo.alias;
- unsigned nameIndex = nameCounts[alias]++;
+ unsigned nameIndex;
+ // If the alias ends with a digit, we need to pretend as if it has trailing
+ // underscore to get a unique nameIndex.
+ if (isdigit(alias.back())) {
+ SmallString<16> aliasBuffer(alias);
+ aliasBuffer.push_back('_');
+ nameIndex = nameCounts[aliasBuffer]++;
+ } else {
+ nameIndex = nameCounts[alias]++;
+ }
symbolToAlias.insert(
{symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
aliasInfo.canBeDeferred)});
@@ -1191,8 +1194,7 @@ void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
SmallString<16> tempBuffer;
StringRef name =
- sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
- /*allowTrailingDigit=*/false);
+ sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-");
name = name.copy(aliasAllocator);
alias = InProgressAliasInfo(name);
}
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index e878d862076c9..97cda270e2c2e 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -5,7 +5,7 @@
// CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
"test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()
-// CHECK-DAG: #test_alias0_ = "alias_test:trailing_digit"
+// CHECK-DAG: #test_alias0 = "alias_test:trailing_digit"
"test.op"() {alias_test = "alias_test:trailing_digit"} : () -> ()
// CHECK-DAG: #_0_test_alias = "alias_test:prefixed_digit"
@@ -14,12 +14,11 @@
// CHECK-DAG: #_25test = "alias_test:prefixed_symbol"
"test.op"() {alias_test = "alias_test:prefixed_symbol"} : () -> ()
-// CHECK-DAG: #test_alias_conflict0_ = "alias_test:sanitize_conflict_a"
-// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:sanitize_conflict_b"
-"test.op"() {alias_test = ["alias_test:sanitize_conflict_a", "alias_test:sanitize_conflict_b"]} : () -> ()
-
-// CHECK-DAG: !tuple = tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>
-"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>)
+// CHECK-DAG: #test_alias_conflict0 = "alias_test:trailing_digit_conflict_a"
+// CHECK-DAG: #test_alias_conflict0_1 = "alias_test:trailing_digit_conflict_b"
+// CHECK-DAG: #test_alias_conflict0_2 = "alias_test:trailing_digit_conflict_c"
+// CHECK-DAG: #test_alias_conflict0_3 = "alias_test:trailing_digit_conflict_d"
+"test.op"() {alias_test = ["alias_test:trailing_digit_conflict_a", "alias_test:trailing_digit_conflict_b", "alias_test:trailing_digit_conflict_c", "alias_test:trailing_digit_conflict_d"]} : () -> ()
// CHECK-DAG: !test_tuple = tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>
"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple<!test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla, !test.smpla>)
@@ -28,8 +27,8 @@
// CHECK-DAG: tensor<32xf32, #test_encoding>
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
-// CHECK-DAG: !test_ui8_ = !test.int<unsigned, 8>
-// CHECK-DAG: tensor<32x!test_ui8_>
+// CHECK-DAG: !test_ui8 = !test.int<unsigned, 8>
+// CHECK-DAG: tensor<32x!test_ui8>
"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
// CHECK-DAG: #[[LOC_NESTED:.+]] = loc("nested")
@@ -47,8 +46,8 @@
// -----
// Ensure self type parameters get considered for aliases.
-// CHECK: !test_ui8_ = !test.int<unsigned, 8>
-// CHECK: #test.attr_with_self_type_param : !test_ui8_
+// CHECK: !test_ui8 = !test.int<unsigned, 8>
+// CHECK: #test.attr_with_self_type_param : !test_ui8
"test.op"() {alias_test = #test.attr_with_self_type_param : !test.int<unsigned, 8> } : () -> ()
// -----
diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir
index 42aecb41d998d..b8111d9601e48 100644
--- a/mlir/test/IR/recursive-type.mlir
+++ b/mlir/test/IR/recursive-type.mlir
@@ -4,8 +4,8 @@
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
-// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
-// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
+// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5>
+// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4>
// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
index 64add8cef3698..065692b98f219 100644
--- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -187,12 +187,16 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
StringSwitch<std::optional<StringRef>>(strAttr.getValue())
.Case("alias_test:dot_in_name", StringRef("test.alias"))
.Case("alias_test:trailing_digit", StringRef("test_alias0"))
- .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
- .Case("alias_test:prefixed_symbol", StringRef("%test"))
- .Case("alias_test:sanitize_conflict_a",
+ .Case("alias_test:trailing_digit_conflict_a",
+ StringRef("test_alias_conflict0"))
+ .Case("alias_test:trailing_digit_conflict_b",
StringRef("test_alias_conflict0"))
- .Case("alias_test:sanitize_conflict_b",
+ .Case("alias_test:trailing_digit_conflict_c",
StringRef("test_alias_conflict0_"))
+ .Case("alias_test:trailing_digit_conflict_d",
+ StringRef("test_alias_conflict0_"))
+ .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
+ .Case("alias_test:prefixed_symbol", StringRef("%test"))
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
.Default(std::nullopt);
if (!aliasName)
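The counting scheme in this diff can be sketched outside of MLIR. The block below is a minimal standalone approximation, assuming plain std::string and std::unordered_map in place of StringRef and the LLVM containers; printAlias and assignAliases are hypothetical names, not functions from AsmPrinter.cpp. It only covers attribute aliases (the '#' prefix).

```cpp
#include <cassert>
#include <cctype>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for SymbolAlias::print (attribute aliases only):
// index 0 prints no suffix, and a '_' separates a trailing digit in the
// name from the numeric suffix.
std::string printAlias(const std::string &name, unsigned suffixIndex) {
  assert(!name.empty() && "Shouldn't have an empty name here");
  std::string out = "#" + name;
  if (suffixIndex) {
    if (std::isdigit(static_cast<unsigned char>(name.back())))
      out += '_';
    out += std::to_string(suffixIndex);
  }
  return out;
}

// Hypothetical stand-in for the loop in initializeAliases: an alias ending
// in a digit shares its counter with the same alias plus a trailing '_'.
std::vector<std::string>
assignAliases(const std::vector<std::string> &suggested) {
  std::unordered_map<std::string, unsigned> nameCounts;
  std::vector<std::string> printed;
  for (const std::string &alias : suggested) {
    assert(!alias.empty() && "Shouldn't have an empty name here");
    std::string key = alias;
    if (std::isdigit(static_cast<unsigned char>(alias.back())))
      key.push_back('_');
    unsigned nameIndex = nameCounts[key]++;
    printed.push_back(printAlias(alias, nameIndex));
  }
  return printed;
}
```

Feeding it the four suggestions from the updated test (test_alias_conflict0 twice, then test_alias_conflict0_ twice) yields the names in the new CHECK-DAG lines. The sketch does not detect a suggested alias that already looks like a generated one, such as name0_1.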
|
River707
left a comment
I'm having trouble finding which of the tests covers the situation where a user defined alias conflicts with autogenerated ones, can you point me to which test is covering this case?
OK I realized for the following case it will conflict:
!test.type<a = 3> // suggest !name0
!test.type<a = 4> // suggest !name0
!test.another<b = 3> // suggest !name0_
!test.another<b = 4> // suggest !name0_1
The result now would be
I think such detection could be made by traversing |
d951418 to
cf2b899
Compare
|
Added detection logic for printing an alias with a trailing digit as is. For example, for |
cf2b899 to
b21dc21
Compare
Checking only prefix is not sound, as the if So the approach by |
|
ping for review |
|
Ping for review |
fabe532 to
9f252c4
Compare
|
Comments addressed. |
mlir/lib/IR/AsmPrinter.cpp
Outdated
      usedAliases.insert(probeAlias);
      break;
    }
    probeAlias.resize(alias.size() + isdigit(alias.back()) ? 1 : 0);
Can you hoist this size computation out of the loop?
This is part of the loop logic: every time the detection fails, probeAlias needs to be reset to alias (plus the underscore, if any) so that the next llvm::utostr(nameIndex) can be appended.
See also
llvm-project/mlir/lib/IR/AsmPrinter.cpp
Lines 1771 to 1785 in 87976ca
} else {
  // Otherwise, we had a conflict - probe until we find a unique name. This
  // is guaranteed to terminate (and usually in a single iteration) because it
  // generates new names by incrementing nextConflictID.
  SmallString<64> probeName(name);
  probeName.push_back('_');
  while (true) {
    probeName += llvm::utostr(nextConflictID++);
    if (!usedNames.count(probeName)) {
      name = probeName.str().copy(usedNameAllocator);
      break;
    }
    probeName.resize(name.size() + 1);
  }
}
9f252c4 to
b29650e
Compare
Also, these test cases were affected by the upstream PR llvm/llvm-project#127993
When generating aliases from OpAsm{Dialect,Type,Attr}Interface, the result is sanitized, and if the alias provided by the interface has a trailing digit, AsmPrinter attaches an underscore to it, presumably to prevent conflicts.
Motivation
There are two reasons to motivate the change from the old behavior to the proposed behavior.
The trailing underscore appears in the printed IR, which aesthetically would be better if we have (!ct_L0, !ct_L0) -> !ct_L0.
Value names are made unique with a _N suffix, which can be similarly applied to alias names. See the IR above where the first one is called %ct and others are called %ct_N. See uniqueValueName for detail.
Conflict detection
The conflict detection is based on nameCounts in initializeAliases.
In the original way, the first two will get sanitized to !name0_ and initializeAlias can assign unique id 0, 1, 2, 3 to them.
In the current way, initializeAlias uses usedAliases to track which names have been used, and uses that information to generate a suffix id that makes the printed alias name unique.
The result for the above example is !name0, !name0_1, !name0_, !name0_2 now.