Skip to content

Commit 39f57c0

Browse files
fix: add shouldBeQualifiedFlag and new printing utils for corresponding cases
1 parent 2338a54 commit 39f57c0

File tree

3 files changed

+57
-7
lines changed

3 files changed

+57
-7
lines changed

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/ADT/Twine.h"
2121
#include "llvm/Support/SMLoc.h"
2222
#include <optional>
23+
#include <type_traits>
2324

2425
namespace {
2526
// reference https://stackoverflow.com/a/16000226
@@ -203,6 +204,43 @@ class AsmPrinter {
203204
*this << attrOrType;
204205
}
205206

207+
/// Print the provided attribute or type ensuring that any dialect-specific
208+
/// prefixes (e.g. `#dialect.mnemonic`) are retained. This is used in
209+
/// contexts, such as structured property dictionaries, where the fully
210+
/// qualified form is required for disambiguation.
211+
template <typename AttrOrType,
212+
std::enable_if_t<std::is_convertible_v<AttrOrType, Attribute>> *
213+
sfinae = nullptr>
214+
void printQualifiedAttrOrType(AttrOrType attrOrType) {
215+
printAttribute(attrOrType);
216+
}
217+
218+
template <typename AttrOrType,
219+
std::enable_if_t<std::is_convertible_v<AttrOrType, Type> &&
220+
!std::is_convertible_v<AttrOrType, Attribute>> *
221+
sfinae = nullptr>
222+
void printQualifiedAttrOrType(AttrOrType attrOrType) {
223+
printType(attrOrType);
224+
}
225+
226+
template <typename ElementT,
227+
std::enable_if_t<std::is_convertible_v<ElementT, Attribute> ||
228+
std::is_convertible_v<ElementT, Type>> *sfinae =
229+
nullptr>
230+
void printQualifiedAttrOrType(ArrayRef<ElementT> attrOrTypes) {
231+
llvm::interleaveComma(attrOrTypes, getStream(), [&](ElementT element) {
232+
printQualifiedAttrOrType(element);
233+
});
234+
}
235+
236+
template <typename AttrOrType,
237+
std::enable_if_t<!std::is_convertible_v<AttrOrType, Attribute> &&
238+
!std::is_convertible_v<AttrOrType, Type>> *
239+
sfinae = nullptr>
240+
void printQualifiedAttrOrType(AttrOrType attrOrType) {
241+
*this << attrOrType;
242+
}
243+
206244
/// Print the given attribute without its type. The corresponding parser must
207245
/// provide a valid type for the attribute.
208246
virtual void printAttributeWithoutType(Attribute attr);

mlir/test/mlir-tblgen/op-format.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,14 +352,14 @@ module attributes {test.internal = #test.internal<key = 8, value = 9>} {
352352

353353
//-----
354354

355-
// CHECK: module attributes {test.external = #test.external<internal = <key = 1, value = 2>>} {
355+
// CHECK: module attributes {test.external = #test.external<internal = #test.internal<key = 1, value = 2>>} {
356356
// CHECK-NEXT: }
357357
module attributes {test.external = #test.external<internal = #test.internal<key = 1, value = 2>>} {
358358
}
359359

360360
//-----
361361

362-
// CHECK: module attributes {test.external_array = #test.external_array<[internals = <key = 1, value = 2>, <key = 8, value = 9>]>} {
362+
// CHECK: module attributes {test.external_array = #test.external_array<[internals = #test.internal<key = 1, value = 2>, #test.internal<key = 8, value = 9>]>} {
363363
// CHECK-NEXT: }
364364
module attributes {test.external_array = #test.external_array<[internals = #test.internal<key = 1, value = 2>, #test.internal<key = 8, value = 9>]>} {
365365
}

mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,15 @@ class StructDirective
141141

142142
} // namespace
143143

144+
/// Return true if the provided parameter should always be printed in qualified
145+
/// form (i.e., with dialect/type prefixes). Attribute and type parameters fall
146+
/// into this category to avoid ambiguity when nested within structured
147+
/// properties.
148+
static bool shouldPrintQualified(ParameterElement *param) {
149+
StringRef cppType = param->getParam().getCppType();
150+
return cppType.contains("Attr") || cppType.contains("Type");
151+
}
152+
144153
//===----------------------------------------------------------------------===//
145154
// Format Strings
146155
//===----------------------------------------------------------------------===//
@@ -155,7 +164,8 @@ static const char *const defaultParameterPrinter =
155164

156165
/// Qualified printer for attribute or type parameters: it does not elide
157166
/// dialect and mnemonic.
158-
static const char *const qualifiedParameterPrinter = "$_printer << $_self";
167+
static const char *const qualifiedParameterPrinter =
168+
"$_printer.printQualifiedAttrOrType($_self)";
159169

160170
/// Print an error when failing to parse an element.
161171
///
@@ -849,13 +859,13 @@ static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os,
849859
}
850860

851861
void DefFormat::genCommaSeparatedPrinter(
852-
ArrayRef<FormatElement *> args, FmtContext &ctx, MethodBody &os,
862+
ArrayRef<FormatElement *> params, FmtContext &ctx, MethodBody &os,
853863
function_ref<void(FormatElement *)> extra) {
854864
// Emit a space if necessary, but only if the struct is present.
855865
if (shouldEmitSpace || !lastWasPunctuation) {
856-
bool allOptional = llvm::all_of(args, formatIsOptional);
866+
bool allOptional = llvm::all_of(params, formatIsOptional);
857867
if (allOptional)
858-
guardOnAnyOptional(ctx, os, args);
868+
guardOnAnyOptional(ctx, os, params);
859869
os << tgfmt("$_printer << ' ';\n", &ctx);
860870
if (allOptional)
861871
os.unindent() << "}\n";
@@ -864,14 +874,16 @@ void DefFormat::genCommaSeparatedPrinter(
864874
// The first printed element does not need to emit a comma.
865875
os << "{\n";
866876
os.indent() << "bool _firstPrinted = true;\n";
867-
for (FormatElement *arg : args) {
877+
for (FormatElement *arg : params) {
868878
ParameterElement *param = getEncapsulatedParameterElement(arg);
869879
if (param->isOptional()) {
870880
param->genPrintGuard(ctx, os << "if (") << ") {\n";
871881
os.indent();
872882
}
873883
os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
874884
os << "_firstPrinted = false;\n";
885+
if (param && shouldPrintQualified(param))
886+
param->setShouldBeQualified();
875887
extra(arg);
876888
shouldEmitSpace = false;
877889
lastWasPunctuation = true;

0 commit comments

Comments
 (0)