diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index e2472eea8a371..b62b6706df967 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -265,8 +265,11 @@ class AsmPrinter { /// Attempts to start a cyclic printing region for `attrOrType`. /// A cyclic printing region starts with this call and ends with the /// destruction of the returned `CyclicPrintReset`. During this time, - /// calling `tryStartCyclicPrint` with the same attribute in any printer - /// will lead to returning failure. + /// calling `tryStartCyclicPrint` with the same attribute or type in any + /// printer will lead to returning failure. Additionally, if the printer + /// knows a complete definition of the attribute or type will be emitted in + /// the future, it'll also return failure to permit abbreviated definitions + /// to be used wherever possible. /// /// This makes it possible to break infinite recursions when trying to print /// cyclic attributes or types by printing only immutable parameters if nested @@ -278,6 +281,8 @@ class AsmPrinter { AttrOrTypeT> || std::is_base_of_v, AttrOrTypeT>, "Only mutable attributes or types can be cyclic"); + if (hasFutureAlias(attrOrType.getAsOpaquePointer())) + return failure(); if (failed(pushCyclicPrinting(attrOrType.getAsOpaquePointer()))) return failure(); return CyclicPrintReset(this); @@ -299,6 +304,12 @@ class AsmPrinter { /// in reverse order of all successful `pushCyclicPrinting`. virtual void popCyclicPrinting(); + /// Check if the given attribute or type (in the form of a type erased + /// pointer) will be printed as an alias in the future. Returns false if the + /// type has an alias that's currently being printed or has already been + /// printed. This enables cyclic print checking for mutual recursion. + virtual bool hasFutureAlias(const void *opaquePointer) const; + private: AsmPrinter(const AsmPrinter &) = delete; void operator=(const AsmPrinter &) = delete; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index a728425f2ec6b..a62443932ddd1 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -449,6 +449,8 @@ class AsmPrinter::Impl { void popCyclicPrinting(); + bool hasFutureAlias(const void *opaquePointer) const; + void printDimensionList(ArrayRef shape); protected: @@ -547,8 +549,13 @@ class SymbolAlias { bool isDeferrable : 1; public: + /// Used to distinguish aliases that are currently being or have previously + /// been printed from those that will be printed in the future, which can aid + /// printing mutually recursive types. + bool hasStartedPrinting = false; + /// Used to avoid printing incomplete aliases for recursive types. - bool isPrinted = false; + bool hasFinishedPrinting = false; }; /// This class represents a utility that initializes the set of attribute and @@ -974,6 +981,8 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); } + bool hasFutureAlias(const void *) const override { return false; } + /// Stack of potentially cyclic mutable attributes or type currently being /// printed. SetVector cyclicPrintingStack; @@ -1182,6 +1191,12 @@ class AliasState { /// Returns success if an alias was printed, failure otherwise. LogicalResult getAlias(Type ty, raw_ostream &os) const; + /// Check if the given attribute or type (in the form of a type erased + /// pointer) will be printed as an alias in the future. Returns false if the + /// type has an alias that's currently being printed or has already been + /// printed. This enables cyclic print checking for mutual recursion. + bool hasFutureAlias(const void *opaquePointer) const; + /// Print all of the referenced aliases that can not be resolved in a deferred /// manner. void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) { @@ -1226,13 +1241,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const { const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer()); if (it == attrTypeToAlias.end()) return failure(); - if (!it->second.isPrinted) + if (!it->second.hasFinishedPrinting) return failure(); it->second.print(os); return success(); } +bool AliasState::hasFutureAlias(const void *opaquePointer) const { + const auto *it = attrTypeToAlias.find(opaquePointer); + if (it == attrTypeToAlias.end()) + return false; + return !it->second.hasStartedPrinting; +} + void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, bool isDeferred) { auto filterFn = [=](const auto &aliasIt) { @@ -1245,8 +1267,9 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine, if (alias.isTypeAlias()) { Type type = Type::getFromOpaquePointer(opaqueSymbol); + alias.hasStartedPrinting = true; p.printTypeImpl(type); - alias.isPrinted = true; + alias.hasFinishedPrinting = true; } else { // TODO: Support nested aliases in mutable attributes. Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol); @@ -2791,6 +2814,10 @@ LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) { void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); } +bool AsmPrinter::Impl::hasFutureAlias(const void *opaquePointer) const { + return state.getAliasState().hasFutureAlias(opaquePointer); +} + void AsmPrinter::Impl::printDimensionList(ArrayRef shape) { detail::printDimensionList(os, shape); } @@ -2870,6 +2897,10 @@ LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) { void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); } +bool AsmPrinter::hasFutureAlias(const void *opaquePointer) const { + return impl->hasFutureAlias(opaquePointer); +} + //===----------------------------------------------------------------------===// // Affine expressions and maps //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir index 42aecb41d998d..c5d0cd09b220f 100644 --- a/mlir/test/IR/recursive-type.mlir +++ b/mlir/test/IR/recursive-type.mlir @@ -2,10 +2,12 @@ // CHECK: !testrec = !test.test_rec> // CHECK: ![[$NAME:.*]] = !test.test_rec_alias> -// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias>>> +// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias> +// CHECK: ![[$NAME7:.*]] = !test.test_rec_alias> // CHECK: ![[$NAME2:.*]] = !test.test_rec_alias, i32>> -// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias -// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias +// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias +// CHECK: ![[$NAME6:.*]] = !test.test_rec_alias +// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias // CHECK-LABEL: @roundtrip func.func @roundtrip() { @@ -28,13 +30,20 @@ func.func @roundtrip() { "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias, i32>> "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias, i32>> - // Mutual recursion. + // Mutual recursion with types fully spelled out. // CHECK: () -> ![[$NAME3]] // CHECK: () -> ![[$NAME4]] // CHECK: () -> ![[$NAME5]] "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias>>> "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias>>> "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias>>> + + // Mutual recursion with incomplete types. + // CHECK: () -> ![[$NAME6]] + // CHECK: () -> ![[$NAME7]] + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias> + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias> + return } diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 1593b6d7d7534..48f519faba40f 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -505,6 +505,10 @@ Type TestRecursiveAliasType::parse(AsmParser &parser) { return rec; } + // Allow incomplete definitions that can be completed later. + if (succeeded(parser.parseGreater())) + return rec; + // Otherwise, parse the body and update the type. if (failed(parser.parseComma())) return Type();