Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,36 @@ def CompoundAttrNested : Test_Attr<"CompoundAttrNested"> {
let assemblyFormat = "`<` `nested` `=` $nested `>`";
}

// Nested attributes for reproducing.
def InternalAttr : Test_Attr<"Internal"> {
let mnemonic = "internal";

let parameters = (ins
"int64_t":$key,
"int64_t":$value
);

let assemblyFormat = "`<` struct(params) `>`";
}

def ExternalAttr : Test_Attr<"External"> {
let mnemonic = "external";

let parameters = (ins InternalAttr:$internal);

let assemblyFormat = "`<` struct(params) `>`";
}

def ExternalArrayAttr : Test_Attr<"ExternalArray"> {
let mnemonic = "external_array";

let parameters = (ins
ArrayRefParameter<"InternalAttr">:$internals
);

let assemblyFormat = "`<` `[` struct(params) `]` `>`";
}

// An attribute testing AttributeSelfTypeParameter.
def AttrWithSelfTypeParam
: Test_Attr<"AttrWithSelfTypeParam", [TypedAttrInterface]> {
Expand Down
21 changes: 21 additions & 0 deletions mlir/test/mlir-tblgen/op-format.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,27 @@ module attributes {test.someAttr = #test.cmpnd_nested_outer<i <42 <1, !test.smpl

//-----

// CHECK: module attributes {test.internal = #test.internal<key = 8, value = 9>} {
// CHECK-NEXT: }
module attributes {test.internal = #test.internal<key = 8, value = 9>} {
}

//-----

// CHECK: module attributes {test.external = #test.external<internal = <key = 1, value = 2>>} {
// CHECK-NEXT: }
module attributes {test.external = #test.external<internal = <key = 1, value = 2>>} {
}

//-----

// CHECK: module attributes {test.external_array = #test.external_array<[internals = <key = 1, value = 2>, <key = 8, value = 9>]>} {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the nested qualified print here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The nested qualified only prints the first qualified mnemonic.

Before:

module attributes {test.external_array = #test.external_array<[internals = #test.internal<key = 1, value = 2>, #test.internal<key = 8, value = 9>]>} {
}

After:

module attributes {test.external_array = #test.external_array<[internals = <key = 1, value = 2>, <key = 8, value = 9>]>} {
}

// CHECK-NEXT: }
module attributes {test.external_array = #test.external_array<[internals = <key = 1, value = 2>, <key = 8, value = 9>]>} {
}

//-----

// CHECK: test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>
test.format_cpmd_nested_attr nested <i <42 <1, !test.smpla, [5, 6]>>>

Expand Down
21 changes: 21 additions & 0 deletions mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,25 @@ class StructDirective

} // namespace

/// Return true if the provided parameter should always be printed in qualified
/// form (i.e., with dialect/type prefixes). Attribute and type parameters fall
/// into this category to avoid ambiguity when nested within structured
/// properties.
static bool shouldPrintQualified(ParameterElement *param) {
const AttrOrTypeParameter &parameter = param->getParam();
StringRef cppType = parameter.getCppType();
if (!cppType.contains("Attr") && !cppType.contains("Type"))
return false;

if (parameter.getPrinter())
return false;

if (cppType.contains("mlir::Attribute") || cppType.contains("mlir::Type"))
return true;

return false;
}

//===----------------------------------------------------------------------===//
// Format Strings
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -872,6 +891,8 @@ void DefFormat::genCommaSeparatedPrinter(
}
os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
os << "_firstPrinted = false;\n";
if (param && shouldPrintQualified(param))
param->setShouldBeQualified();
extra(arg);
shouldEmitSpace = false;
lastWasPunctuation = true;
Expand Down