Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ class AsmPrinter {
/// be printed.
virtual LogicalResult printAlias(Type type);

/// Check if the given type has an alias that will be printed in the future.
/// Returns false if the type has an alias that's currently being printed or
/// has already been printed. This can aid printing mutually recursive types.
virtual bool hasFutureAlias(Type type) const;

/// Print the given string as a keyword, or a quoted and escaped string if it
/// has any special or non-printable characters in it.
virtual void printKeywordOrString(StringRef keyword);
Expand Down
40 changes: 37 additions & 3 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,11 @@ class AsmPrinter::Impl {
/// be printed.
LogicalResult printAlias(Type type);

/// Check if the given type has an alias that will be printed in the future.
/// Returns false if the type has an alias that's currently being printed or
/// has already been printed. This can aid printing mutually recursive types.
bool hasFutureAlias(Type type) const;

/// Print the given location to the stream. If `allowAlias` is true, this
/// allows for the internal location to use an attribute alias.
void printLocation(LocationAttr loc, bool allowAlias = false);
Expand Down Expand Up @@ -547,8 +552,13 @@ class SymbolAlias {
bool isDeferrable : 1;

public:
/// Used to distinguish aliases that are currently being or have previously
/// been printed from those that will be printed in the future, which can aid
/// printing mutually recursive types.
bool hasStartedPrinting = false;

/// Used to avoid printing incomplete aliases for recursive types.
bool isPrinted = false;
bool hasFinishedPrinting = false;
};

/// This class represents a utility that initializes the set of attribute and
Expand Down Expand Up @@ -774,6 +784,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
initializer.visit(type);
return success();
}
bool hasFutureAlias(Type) const override { return false; }

/// Consider the given location to be printed for an alias.
void printOptionalLocationSpecifier(Location loc) override {
Expand Down Expand Up @@ -948,6 +959,7 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
printType(type);
return success();
}
bool hasFutureAlias(Type) const override { return false; }

/// Record the alias result of a child element.
void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
Expand Down Expand Up @@ -1182,6 +1194,11 @@ class AliasState {
/// Returns success if an alias was printed, failure otherwise.
LogicalResult getAlias(Type ty, raw_ostream &os) const;

/// Check if the given type has an alias that will be printed in the future.
/// Returns false if the type has an alias that's currently being printed or
/// has already been printed. This can aid printing mutually recursive types.
bool hasFutureAlias(Type ty) const;

/// Print all of the referenced aliases that can not be resolved in a deferred
/// manner.
void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
Expand Down Expand Up @@ -1226,13 +1243,20 @@ LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer());
if (it == attrTypeToAlias.end())
return failure();
if (!it->second.isPrinted)
if (!it->second.hasFinishedPrinting)
return failure();

it->second.print(os);
return success();
}

bool AliasState::hasFutureAlias(Type ty) const {
const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer());
if (it == attrTypeToAlias.end())
return false;
return !it->second.hasStartedPrinting;
}

void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
bool isDeferred) {
auto filterFn = [=](const auto &aliasIt) {
Expand All @@ -1245,8 +1269,9 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,

if (alias.isTypeAlias()) {
Type type = Type::getFromOpaquePointer(opaqueSymbol);
alias.hasStartedPrinting = true;
p.printTypeImpl(type);
alias.isPrinted = true;
alias.hasFinishedPrinting = true;
} else {
// TODO: Support nested aliases in mutable attributes.
Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol);
Expand Down Expand Up @@ -2234,6 +2259,10 @@ LogicalResult AsmPrinter::Impl::printAlias(Type type) {
return state.getAliasState().getAlias(type, os);
}

bool AsmPrinter::Impl::hasFutureAlias(Type type) const {
return state.getAliasState().hasFutureAlias(type);
}

void AsmPrinter::Impl::printAttribute(Attribute attr,
AttrTypeElision typeElision) {
if (!attr) {
Expand Down Expand Up @@ -2832,6 +2861,11 @@ LogicalResult AsmPrinter::printAlias(Type type) {
return impl->printAlias(type);
}

bool AsmPrinter::hasFutureAlias(Type type) const {
assert(impl && "expected AsmPrinter::hasFutureAlias to be overridden");
return impl->hasFutureAlias(type);
}

void AsmPrinter::printAttributeWithoutType(Attribute attr) {
assert(impl &&
"expected AsmPrinter::printAttributeWithoutType to be overriden");
Expand Down
17 changes: 13 additions & 4 deletions mlir/test/IR/recursive-type.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

// CHECK: !testrec = !test.test_rec<type_to_alias, test_rec<type_to_alias>>
// CHECK: ![[$NAME:.*]] = !test.test_rec_alias<name, !test.test_rec_alias<name>>
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>
// CHECK: ![[$NAME5:.*]] = !test.test_rec_alias<name5, !test.test_rec_alias<name3>>
// CHECK: ![[$NAME7:.*]] = !test.test_rec_alias<name7, !test.test_rec_alias<name6>>
// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, !name5_>
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, !name4_>
// CHECK: ![[$NAME4:.*]] = !test.test_rec_alias<name4, ![[$NAME5]]>
// CHECK: ![[$NAME6:.*]] = !test.test_rec_alias<name6, ![[$NAME7]]>
// CHECK: ![[$NAME3:.*]] = !test.test_rec_alias<name3, ![[$NAME4]]>

// CHECK-LABEL: @roundtrip
func.func @roundtrip() {
Expand All @@ -28,13 +30,20 @@ func.func @roundtrip() {
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name2, tuple<!test.test_rec_alias<name2>, i32>>

// Mutual recursion.
// Mutual recursion with types fully spelled out.
// CHECK: () -> ![[$NAME3]]
// CHECK: () -> ![[$NAME4]]
// CHECK: () -> ![[$NAME5]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5, !test.test_rec_alias<name3>>>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name4, !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4>>>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name5, !test.test_rec_alias<name3, !test.test_rec_alias<name4, !test.test_rec_alias<name5>>>>

// Mutual recursion with incomplete types.
// CHECK: () -> ![[$NAME6]]
// CHECK: () -> ![[$NAME7]]
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name6, !test.test_rec_alias<name7>>
"test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias<name7, !test.test_rec_alias<name6>>

return
}

Expand Down
6 changes: 5 additions & 1 deletion mlir/test/lib/Dialect/Test/TestTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,10 @@ Type TestRecursiveAliasType::parse(AsmParser &parser) {
return rec;
}

// Allow incomplete definitions that can be completed later.
if (succeeded(parser.parseGreater()))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (succeeded(parser.parseGreater()))
if (succeeded(parser.parseOptionalGreater()))

Surprised this doesn't cause any issues I must admit.

return rec;

// Otherwise, parse the body and update the type.
if (failed(parser.parseComma()))
return Type();
Expand All @@ -525,7 +529,7 @@ void TestRecursiveAliasType::print(AsmPrinter &printer) const {
printer.tryStartCyclicPrint(*this);

printer << "<" << getName();
if (succeeded(cyclicPrint)) {
if (succeeded(cyclicPrint) && !printer.hasFutureAlias(*this)) {
printer << ", ";
printer << getBody();
}
Expand Down
Loading