diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index d70aa346eaa1f..fb02149eebc8d 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -130,6 +130,16 @@ class AsmPrinter { /// Return the raw output stream used by this printer. virtual raw_ostream &getStream() const; + /// Print a newline and indent the printer to the start of the current + /// operation. + virtual void printNewline(); + + /// Increase indentation. + virtual void increaseIndent(); + + /// Decrease indentation. + virtual void decreaseIndent(); + /// Print the given floating point value in a stabilized form that can be /// roundtripped through the IR. This is the companion to the 'parseFloat' /// hook on the AsmParser. @@ -448,16 +458,6 @@ class OpAsmPrinter : public AsmPrinter { /// Print a loc(...) specifier if printing debug info is enabled. virtual void printOptionalLocationSpecifier(Location loc) = 0; - /// Print a newline and indent the printer to the start of the current - /// operation. - virtual void printNewline() = 0; - - /// Increase indentation. - virtual void increaseIndent() = 0; - - /// Decrease indentation. - virtual void decreaseIndent() = 0; - /// Print a block argument in the usual format of: /// %ssaName : type {attr1=42} loc("here") /// where location printing is controlled by the standard internal option. diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index de52fbd3f215c..5760e52d6efc4 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -418,6 +418,19 @@ class AsmPrinter::Impl { /// Returns the output stream of the printer. raw_ostream &getStream() { return os; } + /// Print a newline and indent the printer to the start of the current + /// operation. + void printNewline() { + os << newLine; + os.indent(currentIndent); + } + + /// Increase indentation. + void increaseIndent() { currentIndent += indentWidth; } + + /// Decrease indentation. + void decreaseIndent() { currentIndent -= indentWidth; } + template inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const { llvm::interleaveComma(c, os, eachFn); @@ -532,6 +545,12 @@ class AsmPrinter::Impl { /// A tracker for the number of new lines emitted during printing. NewLineCounter newLine; + + /// The number of spaces used for indenting nested operations. + const static unsigned indentWidth = 2; + + /// This is the current indentation level for nested structures. + unsigned currentIndent = 0; }; } // namespace mlir @@ -1004,6 +1023,9 @@ class DummyAliasDialectAsmPrinter : public DialectAsmPrinter { /// The following are hooks of `DialectAsmPrinter` that are not necessary for /// determining potential aliases. + void printNewline() override {} + void increaseIndent() override {} + void decreaseIndent() override {} void printFloat(const APFloat &) override {} void printKeywordOrString(StringRef) override {} void printString(StringRef) override {} @@ -2892,6 +2914,13 @@ void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { { llvm::raw_string_ostream attrNameStr(attrName); Impl subPrinter(attrNameStr, state); + + // The values of currentIndent and newLine are assigned to the created + // subprinter, so that the indent level and number of printed lines can be + // tracked. + subPrinter.currentIndent = currentIndent; + subPrinter.newLine = newLine; + DialectAsmPrinter printer(subPrinter); dialect.printAttribute(attr, printer); } @@ -2906,6 +2935,13 @@ void AsmPrinter::Impl::printDialectType(Type type) { { llvm::raw_string_ostream typeNameStr(typeName); Impl subPrinter(typeNameStr, state); + + // The values of currentIndent and newLine are assigned to the created + // subprinter, so that the indent level and number of printed lines can be + // tracked. + subPrinter.currentIndent = currentIndent; + subPrinter.newLine = newLine; + DialectAsmPrinter printer(subPrinter); dialect.printType(type, printer); } @@ -2946,6 +2982,21 @@ raw_ostream &AsmPrinter::getStream() const { return impl->getStream(); } +void AsmPrinter::printNewline() { + assert(impl && "expected AsmPrinter::printNewLine to be overriden"); + impl->printNewline(); +} + +void AsmPrinter::increaseIndent() { + assert(impl && "expected AsmPrinter::increaseIndent to be overriden"); + impl->increaseIndent(); +} + +void AsmPrinter::decreaseIndent() { + assert(impl && "expected AsmPrinter::decreaseIndent to be overriden"); + impl->decreaseIndent(); +} + /// Print the given floating point value in a stablized form. void AsmPrinter::printFloat(const APFloat &value) { assert(impl && "expected AsmPrinter::printFloat to be overriden"); @@ -3276,19 +3327,6 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { printTrailingLocation(loc); } - /// Print a newline and indent the printer to the start of the current - /// operation. - void printNewline() override { - os << newLine; - os.indent(currentIndent); - } - - /// Increase indentation. - void increaseIndent() override { currentIndent += indentWidth; } - - /// Decrease indentation. - void decreaseIndent() override { currentIndent -= indentWidth; } - /// Print a block argument in the usual format of: /// %ssaName : type {attr1=42} loc("here") /// where location printing is controlled by the standard internal option. @@ -3414,12 +3452,6 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter { // top-level we start with "builtin" as the default, so that the top-level // `module` operation prints as-is. SmallVector defaultDialectStack{"builtin"}; - - /// The number of spaces used for indenting nested operations. - const static unsigned indentWidth = 2; - - // This is the current indentation level for nested structures. - unsigned currentIndent = 0; }; } // namespace diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 5685004bbbd25..f98250c40f384 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -439,4 +439,10 @@ def TestCustomStorageCtorAttr : Test_Attr<"TestCustomStorageCtorAttr"> { let hasStorageCustomConstructor = 1; } +def TestAttrNewlineAndIndent : Test_Attr<"TestAttrNewlineAndIndent"> { + let mnemonic = "newline_and_indent"; + let parameters = (ins "::mlir::Type":$indentType); + let hasCustomAssemblyFormat = 1; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index 58909131e50a3..22ccedd89b1b3 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -430,6 +430,35 @@ bool TestConstMemorySpaceAttr::isValidPtrIntCast( return false; } +//===----------------------------------------------------------------------===// +// TestAttrNewlineAndIndent +//===----------------------------------------------------------------------===// + +Attribute TestAttrNewlineAndIndentAttr::parse(::mlir::AsmParser &parser, + ::mlir::Type type) { + Type indentType; + if (parser.parseLess()) { + return {}; + } + if (parser.parseType(indentType)) { + return {}; + } + if (parser.parseGreater()) { + return {}; + } + return get(parser.getContext(), indentType); +} + +void TestAttrNewlineAndIndentAttr::print(::mlir::AsmPrinter &printer) const { + printer << "<"; + printer.increaseIndent(); + printer.printNewline(); + printer << getIndentType(); + printer.decreaseIndent(); + printer.printNewline(); + printer << ">"; +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index ea20597231d58..070ad0e6e56c4 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -470,4 +470,9 @@ def TestMemrefType : Test_Type<"TestMemref", }]; } +def TestTypeNewlineAndIndent : Test_Type<"TestTypeNewlineAndIndent"> { + let mnemonic = "newline_and_indent"; + let hasCustomAssemblyFormat = 1; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index bea043f56fe21..4764a395a3afc 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -573,3 +573,30 @@ ::mlir::LogicalResult TestTensorType::verifyCompatibleBufferType( getElementType() == testMemref.getElementType(); return mlir::success(valid); } + +//===----------------------------------------------------------------------===// +// TestTypeNewlineAndIndent +//===----------------------------------------------------------------------===// + +Type TestTypeNewlineAndIndentType::parse(::mlir::AsmParser &parser) { + if (parser.parseLess()) { + return {}; + } + if (parser.parseKeyword("indented_content")) { + return {}; + } + if (parser.parseGreater()) { + return {}; + } + return get(parser.getContext()); +} + +void TestTypeNewlineAndIndentType::print(::mlir::AsmPrinter &printer) const { + printer << "<"; + printer.increaseIndent(); + printer.printNewline(); + printer << "indented_content"; + printer.decreaseIndent(); + printer.printNewline(); + printer << ">"; +} diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir index 89ad3594eebd8..dca46a30e37f1 100644 --- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s --strict-whitespace // CHECK-LABEL: func private @compoundA() // CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]> @@ -44,3 +44,19 @@ func.func private @hexdecimalInteger() attributes { // expected-error @below {{expected an integer}} sdg = #test.decimal_shape<1x0xb> } + +// ----- + +// CHECK-LABEL: @newlineAndIndent +// CHECK-SAME: indent = #test.newline_and_indent< +// CHECK-NEXT: {{^ }}!test.newline_and_indent< +// CHECK-NEXT: {{^ }}indented_content +// CHECK-NEXT: {{^ }}> +// CHECK-NEXT: {{^ }}> +func.func private @newlineAndIndent() attributes { + indent = #test.newline_and_indent< + !test.newline_and_indent< + indented_content + > + > +} diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir index 18175edc81cf0..c00d368fdab8b 100644 --- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s --strict-whitespace ////////////// // Tests the types in the 'Test' dialect, not the ones in 'typedefs.mlir' @@ -42,3 +42,13 @@ func.func @testInt(%A : !test.int, %B : !test.int, %C : !test func.func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int} > ) { return } + +// CHECK-LABEL: @newlineAndIndent +// CHECK-SAME: !test.newline_and_indent< +// CHECK-NEXT: {{^ }}indented_content +// CHECK-NEXT: {{^ }}> +func.func @newlineAndIndent(%A : !test.newline_and_indent< + indented_content +>) { + return +} diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index 10a162f81ba13..ace73db03f569 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -958,9 +958,7 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx, void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx, MethodBody &os) { if (el->getValue() == "\\n") { - // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by - // the printer. - os << tgfmt("$_printer << '\\n';\n", &ctx); + os << tgfmt("$_printer.printNewline();\n", &ctx); } else if (!el->getValue().empty()) { os << tgfmt("$_printer << \"$0\";\n", &ctx, el->getValue()); } else {