From 9aeefa8eef91b9966bd5b0226bd70e59c783e0fc Mon Sep 17 00:00:00 2001 From: Shoaib Meenai Date: Fri, 11 Oct 2024 22:02:29 -0700 Subject: [PATCH 1/2] [mlir] Support better printing for mutually recursive types For mutually recursive types, the current way types are printed forces the earlier type alias to include a full definition of the later type. Many recursive types (e.g. structs in ClangIR) have a notion of an incomplete type definition, and by exposing a simple hook in the AsmPrinter to determine whether a type will be printed in the future, we can enable dialects to use incomplete type definitions (which they know will be completed later) when printing mutually recursive types instead, which makes them much easier to read. --- mlir/include/mlir/IR/OpImplementation.h | 5 +++ mlir/lib/IR/AsmPrinter.cpp | 40 ++++++++++++++++++++++-- mlir/test/IR/recursive-type.mlir | 17 +++++++--- mlir/test/lib/Dialect/Test/TestTypes.cpp | 6 +++- 4 files changed, 60 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index e2472eea8a371..2b79727d8c932 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -188,6 +188,11 @@ class AsmPrinter { /// be printed. virtual LogicalResult printAlias(Type type); + /// Check if the given type has an alias that will be printed in the future. + /// Returns false if the type has an alias that's currently being printed or + /// has already been printed. This can aid printing mutually recursive types. + virtual bool hasFutureAlias(Type type) const; + /// Print the given string as a keyword, or a quoted and escaped string if it /// has any special or non-printable characters in it. virtual void printKeywordOrString(StringRef keyword); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index a728425f2ec6b..2c5e0f5b92a4e 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -430,6 +430,11 @@ class AsmPrinter::Impl { /// be printed. LogicalResult printAlias(Type type); + /// Check if the given type has an alias that will be printed in the future. + /// Returns false if the type has an alias that's currently being printed or + /// has already been printed. This can aid printing mutually recursive types. + bool hasFutureAlias(Type type) const; + /// Print the given location to the stream. If `allowAlias` is true, this /// allows for the internal location to use an attribute alias. void printLocation(LocationAttr loc, bool allowAlias = false); @@ -547,8 +552,13 @@ class SymbolAlias { bool isDeferrable : 1; public: + /// Used to distinguish aliases that are currently being or have previously + /// been printed from those that will be printed in the future, which can aid + /// printing mutually recursive types. + bool hasStartedPrinting = false; + /// Used to avoid printing incomplete aliases for recursive types. - bool isPrinted = false; + bool hasFinishedPrinting = false; }; /// This class represents a utility that initializes the set of attribute and @@ -774,6 +784,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter { initializer.visit(type); return success(); } + bool hasFutureAlias(Type) const override { return false; } /// Consider the given location to be printed for an alias. void printOptionalLocationSpecifier(Location loc) override { @@ -948,6 +959,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { printType(type); return success(); } + bool hasFutureAlias(Type) const override { return false; } /// Record the alias result of a child element. void recordAliasResult(std::pair aliasDepthAndIndex) { @@ -1182,6 +1194,11 @@ class AliasState { /// Returns success if an alias was printed, failure otherwise. LogicalResult getAlias(Type ty, raw_ostream &os) const; + /// Check if the given type has an alias that will be printed in the future. + /// Returns false if the type has an alias that's currently being printed or + /// has already been printed. This can aid printing mutually recursive types. + bool hasFutureAlias(Type ty) const; + /// Print all of the referenced aliases that can not be resolved in a deferred /// manner. void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { @@ -1226,13 +1243,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer()); if (it == attrTypeToAlias.end()) return failure(); - if (!it->second.isPrinted) + if (!it->second.hasFinishedPrinting) return failure(); it->second.print(os); return success(); } +bool AliasState::hasFutureAlias(Type ty) const { + const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer()); + if (it == attrTypeToAlias.end()) + return false; + return !it->second.hasStartedPrinting; +} + void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, bool isDeferred) { auto filterFn = [=](const auto &aliasIt) { @@ -1245,8 +1269,9 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, if (alias.isTypeAlias()) { Type type = Type::getFromOpaquePointer(opaqueSymbol); + alias.hasStartedPrinting = true; p.printTypeImpl(type); - alias.isPrinted = true; + alias.hasFinishedPrinting = true; } else { // TODO: Support nested aliases in mutable attributes. Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol); @@ -2234,6 +2259,10 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) { return state.getAliasState().getAlias(type, os); } +bool AsmPrinter::Impl::hasFutureAlias(Type type) const { + return state.getAliasState().hasFutureAlias(type); +} + void AsmPrinter::Impl::printAttribute(Attribute attr, AttrTypeElision typeElision) { if (!attr) { @@ -2832,6 +2861,11 @@ LogicalResult AsmPrinter::printAlias(Type type) { return impl->printAlias(type); } +bool AsmPrinter::hasFutureAlias(Type type) const { + assert(impl && "expected AsmPrinter::hasFutureAlias to be overridden"); + return impl->hasFutureAlias(type); +} + void AsmPrinter::printAttributeWithoutType(Attribute attr) { assert(impl && "expected AsmPrinter::printAttributeWithoutType to be overriden"); diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir index 42aecb41d998d..c5d0cd09b220f 100644 --- a/mlir/test/IR/recursive-type.mlir +++ b/mlir/test/IR/recursive-type.mlir @@ -2,10 +2,12 @@ // CHECK: !testrec = !test.test_rec> // CHECK: ![[$NAME:.*]] = !test.test_rec_alias> -// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias>>> +// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias> +// CHECK: ![[$NAME7:.*]] = !test.test_rec_alias> // CHECK: ![[$NAME2:.*]] = !test.test_rec_alias, i32>> -// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias -// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias +// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias +// CHECK: ![[$NAME6:.*]] = !test.test_rec_alias +// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias // CHECK-LABEL: @roundtrip func.func @roundtrip() { @@ -28,13 +30,20 @@ func.func @roundtrip() { "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias, i32>> "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias, i32>> - // Mutual recursion. + // Mutual recursion with types fully spelled out. // CHECK: () -> ![[$NAME3]] // CHECK: () -> ![[$NAME4]] // CHECK: () -> ![[$NAME5]] "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias>>> "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias>>> "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias>>> + + // Mutual recursion with incomplete types. + // CHECK: () -> ![[$NAME6]] + // CHECK: () -> ![[$NAME7]] + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias> + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias> + return } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 1593b6d7d7534..94576544ba306 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -505,6 +505,10 @@ Type TestRecursiveAliasType::parse(AsmParser &parser) { return rec; } + // Allow incomplete definitions that can be completed later. + if (succeeded(parser.parseGreater())) + return rec; + // Otherwise, parse the body and update the type. if (failed(parser.parseComma())) return Type(); @@ -525,7 +529,7 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const { printer.tryStartCyclicPrint(*this); printer << "<" << getName(); - if (succeeded(cyclicPrint)) { + if (succeeded(cyclicPrint) && !printer.hasFutureAlias(*this)) { printer << ", "; printer << getBody(); } From 75d62775d20ac038edb8cc7a7ffb5e54fdc73026 Mon Sep 17 00:00:00 2001 From: Shoaib Meenai Date: Mon, 14 Oct 2024 16:38:23 -0700 Subject: [PATCH 2/2] Incorporate into tryStartCyclicPrint --- mlir/include/mlir/IR/OpImplementation.h | 20 ++++++++---- mlir/lib/IR/AsmPrinter.cpp | 41 +++++++++++------------- mlir/test/lib/Dialect/Test/TestTypes.cpp | 2 +- 3 files changed, 33 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 2b79727d8c932..b62b6706df967 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -188,11 +188,6 @@ class AsmPrinter { /// be printed. virtual LogicalResult printAlias(Type type); - /// Check if the given type has an alias that will be printed in the future. - /// Returns false if the type has an alias that's currently being printed or - /// has already been printed. This can aid printing mutually recursive types. - virtual bool hasFutureAlias(Type type) const; - /// Print the given string as a keyword, or a quoted and escaped string if it /// has any special or non-printable characters in it. virtual void printKeywordOrString(StringRef keyword); @@ -270,8 +265,11 @@ class AsmPrinter { /// Attempts to start a cyclic printing region for `attrOrType`. /// A cyclic printing region starts with this call and ends with the /// destruction of the returned `CyclicPrintReset`. During this time, - /// calling `tryStartCyclicPrint` with the same attribute in any printer - /// will lead to returning failure. + /// calling `tryStartCyclicPrint` with the same attribute or type in any + /// printer will lead to returning failure. Additionally, if the printer + /// knows a complete definition of the attribute or type will be emitted in + /// the future, it'll also return failure to permit abbreviated definitions + /// to be used wherever possible. /// /// This makes it possible to break infinite recursions when trying to print /// cyclic attributes or types by printing only immutable parameters if nested @@ -283,6 +281,8 @@ class AsmPrinter { AttrOrTypeT> || std::is_base_of_v, AttrOrTypeT>, "Only mutable attributes or types can be cyclic"); + if (hasFutureAlias(attrOrType.getAsOpaquePointer())) + return failure(); if (failed(pushCyclicPrinting(attrOrType.getAsOpaquePointer()))) return failure(); return CyclicPrintReset(this); @@ -304,6 +304,12 @@ class AsmPrinter { /// in reverse order of all successful `pushCyclicPrinting`. virtual void popCyclicPrinting(); + /// Check if the given attribute or type (in the form of a type erased + /// pointer) will be printed as an alias in the future. Returns false if the + /// type has an alias that's currently being printed or has already been + /// printed. This enables cyclic print checking for mutual recursion. + virtual bool hasFutureAlias(const void *opaquePointer) const; + private: AsmPrinter(const AsmPrinter &) = delete; void operator=(const AsmPrinter &) = delete; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 2c5e0f5b92a4e..a62443932ddd1 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -430,11 +430,6 @@ class AsmPrinter::Impl { /// be printed. LogicalResult printAlias(Type type); - /// Check if the given type has an alias that will be printed in the future. - /// Returns false if the type has an alias that's currently being printed or - /// has already been printed. This can aid printing mutually recursive types. - bool hasFutureAlias(Type type) const; - /// Print the given location to the stream. If `allowAlias` is true, this /// allows for the internal location to use an attribute alias. void printLocation(LocationAttr loc, bool allowAlias = false); @@ -454,6 +449,8 @@ class AsmPrinter::Impl { void popCyclicPrinting(); + bool hasFutureAlias(const void *opaquePointer) const; + void printDimensionList(ArrayRef shape); protected: @@ -784,7 +781,6 @@ class DummyAliasOperationPrinter : private OpAsmPrinter { initializer.visit(type); return success(); } - bool hasFutureAlias(Type) const override { return false; } /// Consider the given location to be printed for an alias. void printOptionalLocationSpecifier(Location loc) override { @@ -959,7 +955,6 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { printType(type); return success(); } - bool hasFutureAlias(Type) const override { return false; } /// Record the alias result of a child element. void recordAliasResult(std::pair aliasDepthAndIndex) { @@ -986,6 +981,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); } + bool hasFutureAlias(const void *) const override { return false; } + /// Stack of potentially cyclic mutable attributes or type currently being /// printed. SetVector cyclicPrintingStack; @@ -1194,10 +1191,11 @@ class AliasState { /// Returns success if an alias was printed, failure otherwise. LogicalResult getAlias(Type ty, raw_ostream &os) const; - /// Check if the given type has an alias that will be printed in the future. - /// Returns false if the type has an alias that's currently being printed or - /// has already been printed. This can aid printing mutually recursive types. - bool hasFutureAlias(Type ty) const; + /// Check if the given attribute or type (in the form of a type erased + /// pointer) will be printed as an alias in the future. Returns false if the + /// type has an alias that's currently being printed or has already been + /// printed. This enables cyclic print checking for mutual recursion. + bool hasFutureAlias(const void *opaquePointer) const; /// Print all of the referenced aliases that can not be resolved in a deferred /// manner. @@ -1250,8 +1248,8 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { return success(); } -bool AliasState::hasFutureAlias(Type ty) const { - const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer()); +bool AliasState::hasFutureAlias(const void *opaquePointer) const { + const auto *it = attrTypeToAlias.find(opaquePointer); if (it == attrTypeToAlias.end()) return false; return !it->second.hasStartedPrinting; @@ -2259,10 +2257,6 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) { return state.getAliasState().getAlias(type, os); } -bool AsmPrinter::Impl::hasFutureAlias(Type type) const { - return state.getAliasState().hasFutureAlias(type); -} - void AsmPrinter::Impl::printAttribute(Attribute attr, AttrTypeElision typeElision) { if (!attr) { @@ -2820,6 +2814,10 @@ LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) { void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); } +bool AsmPrinter::Impl::hasFutureAlias(const void *opaquePointer) const { + return state.getAliasState().hasFutureAlias(opaquePointer); +} + void AsmPrinter::Impl::printDimensionList(ArrayRef shape) { detail::printDimensionList(os, shape); } @@ -2861,11 +2859,6 @@ LogicalResult AsmPrinter::printAlias(Type type) { return impl->printAlias(type); } -bool AsmPrinter::hasFutureAlias(Type type) const { - assert(impl && "expected AsmPrinter::hasFutureAlias to be overridden"); - return impl->hasFutureAlias(type); -} - void AsmPrinter::printAttributeWithoutType(Attribute attr) { assert(impl && "expected AsmPrinter::printAttributeWithoutType to be overriden"); @@ -2904,6 +2897,10 @@ LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) { void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); } +bool AsmPrinter::hasFutureAlias(const void *opaquePointer) const { + return impl->hasFutureAlias(opaquePointer); +} + //===----------------------------------------------------------------------===// // Affine expressions and maps //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 94576544ba306..48f519faba40f 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -529,7 +529,7 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const { printer.tryStartCyclicPrint(*this); printer << "<" << getName(); - if (succeeded(cyclicPrint) && !printer.hasFutureAlias(*this)) { + if (succeeded(cyclicPrint)) { printer << ", "; printer << getBody(); }