Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions mlir/docs/DefiningDialects/AttributesAndTypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -842,9 +842,9 @@ if they are not present.

###### `struct` Directive

The `struct` directive accepts a list of variables to capture and will generate
a parser and printer for a comma-separated list of key-value pairs. If an
optional parameter is included in the `struct`, it can be elided. The variables
The `struct` directive accepts a list of variables or directives to capture and
will generate a parser and printer for a comma-separated list of key-value pairs.
If an optional parameter is included in the `struct`, it can be elided. The variables
are printed in the order they are specified in the argument list **but can be
parsed in any order**. For example:

Expand Down Expand Up @@ -876,6 +876,13 @@ assembly format of `` `<` struct(params) `>` `` will result in:
The order in which the parameters are printed is the order in which they are
declared in the attribute's or type's `parameter` list.

Passing `custom<Foo>($variable)` allows providing a custom printer and parser
for the encapsulated variable. Check the
[custom and ref directive](#custom-and-ref-directive) section for more
information about how to define the printer and parser functions. Note that a
custom directive within a struct directive can only encapsulate a single
variable.

###### `custom` and `ref` directive

The `custom` directive is used to dispatch calls to user-defined printer and
Expand Down
68 changes: 68 additions & 0 deletions mlir/test/IR/custom-struct-attr-roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s

// CHECK-LABEL: @test_struct_attr_roundtrip
func.func @test_struct_attr_roundtrip() -> () {
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
// CHECK: attr = #test.custom_struct<type_str = "struct", value = ?, opt_value = 1>
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = ?, opt_value = 1>} : () -> ()
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
return
}

// -----

// Verify all required parameters must be provided. `value` is missing.

// expected-error @below {{struct is missing required parameter: value}}
"test.op"() {attr = #test.custom_struct<type_str = "struct">} : () -> ()

// -----

// Verify all keywords must be provided. All missing.

// expected-error @below {{expected valid keyword}}
// expected-error @below {{expected a parameter name in struct}}
"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()

// -----

// Verify all keywords must be provided. `type_str` missing.

// expected-error @below {{expected valid keyword}}
// expected-error @below {{expected a parameter name in struct}}
"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()

// -----

// Verify all keywords must be provided. `value` missing.

// expected-error @below {{expected valid keyword}}
// expected-error @below {{expected a parameter name in struct}}
"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()

// -----

// Verify invalid keyword provided.

// expected-error @below {{duplicate or unknown struct parameter name: type_str2}}
"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()

// -----

// Verify duplicated keyword provided.

// expected-error @below {{duplicate or unknown struct parameter name: type_str}}
"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()

// -----

// Verify equals missing.

// expected-error @below {{expected '='}}
"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()
10 changes: 10 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,16 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
}];
}

// Test `struct` with nested `custom` assembly format.
def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
let mnemonic = "custom_struct";
let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
OptionalParameter<"mlir::ArrayAttr">:$opt_value);
let assemblyFormat = [{
`<` struct($type_str, custom<CustomStructAttr>($value), custom<CustomOptStructFieldAttr>($opt_value)) `>`
}];
}

def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
let mnemonic = "nested_polynomial";
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);
Expand Down
43 changes: 43 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,49 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
return success();
}

//===----------------------------------------------------------------------===//
// TestCustomStructAttr
//===----------------------------------------------------------------------===//

static void printCustomStructAttr(AsmPrinter &p, int64_t value) {
if (ShapedType::isDynamic(value)) {
p << "?";
} else {
p.printStrippedAttrOrType(value);
}
}

static ParseResult parseCustomStructAttr(AsmParser &p, int64_t &value) {
if (succeeded(p.parseOptionalQuestion())) {
value = ShapedType::kDynamic;
return success();
}
return p.parseInteger(value);
}

static void printCustomOptStructFieldAttr(AsmPrinter &p, ArrayAttr attr) {
if (attr && attr.size() == 1 && isa<IntegerAttr>(attr[0])) {
p << cast<IntegerAttr>(attr[0]).getInt();
} else {
p.printStrippedAttrOrType(attr);
}
}

static ParseResult parseCustomOptStructFieldAttr(AsmParser &p,
ArrayAttr &attr) {
int64_t value;
OptionalParseResult result = p.parseOptionalInteger(value);
if (result.has_value()) {
if (failed(result.value()))
return failure();
attr = ArrayAttr::get(
p.getContext(),
{IntegerAttr::get(IntegerType::get(p.getContext(), 64), value)});
return success();
}
return p.parseAttribute(attr);
}

//===----------------------------------------------------------------------===//
// TestOpAsmAttrInterfaceAttr
//===----------------------------------------------------------------------===//
Expand Down
25 changes: 23 additions & 2 deletions mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
let parameters = (ins "int":$v0);
// CHECK: literals may only be used in the top-level section of the format
// CHECK: expected a variable in `struct` argument list
// CHECK: expected a parameter or `custom` directive in `struct` argument list
let assemblyFormat = "`<` struct($v0, `,`) `>`";
}

// Test struct directive cannot capture zero parameters.
def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
let parameters = (ins "int":$v0);
// CHECK: `struct` argument list expected a variable or directive
// CHECK: `struct` argument list expected a parameter or directive
let assemblyFormat = "`<` struct() $v0 `>`";
}

Expand Down Expand Up @@ -144,3 +144,24 @@ def InvalidTypeT : InvalidType<"InvalidTypeT", "invalid_t"> {
// CHECK: `custom` directive with no bound parameters cannot be used as optional group anchor
let assemblyFormat = "$a (`(` custom<Foo>(ref($a))^ `)`)?";
}

// Test `struct` with nested `custom` directive with multiple fields.
def InvalidTypeU : InvalidType<"InvalidTypeU", "invalid_u"> {
let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
// CHECK: `struct` can only contain `custom` directives with a single argument
let assemblyFormat = "struct(custom<Foo>($a, $b))";
}

// Test `struct` with nested `custom` directive invalid parameter.
def InvalidTypeV : InvalidType<"InvalidTypeV", "invalid_v"> {
let parameters = (ins OptionalParameter<"int">:$a);
// CHECK: a `custom` directive nested within a `struct` must be passed a parameter
let assemblyFormat = "struct($a, custom<Foo>(ref($a)))";
}

// Test `custom` with nested `custom` directive invalid parameter.
def InvalidTypeW : InvalidType<"InvalidTypeV", "invalid_v"> {
let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
// CHECK: `custom` can only be used at the top-level context or within a `struct` directive
let assemblyFormat = "custom<Foo>($a, custom<Bar>($b))";
}
78 changes: 78 additions & 0 deletions mlir/test/mlir-tblgen/attr-or-type-format.td
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,84 @@ def TypeS : TestType<"TestS"> {
let assemblyFormat = "$a";
}

/// Test that a `struct` with nested `custom` parser and printer are generated correctly.

// ATTR: ::mlir::Attribute TestTAttr::parse(::mlir::AsmParser &odsParser,
// ATTR: ::mlir::Type odsType) {
// ATTR: bool _seen_v0 = false;
// ATTR: bool _seen_v1 = false;
// ATTR: bool _seen_v2 = false;
// ATTR: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
// ATTR: if (odsParser.parseEqual())
// ATTR: return {};
// ATTR: if (!_seen_v0 && _paramKey == "v0") {
// ATTR: _seen_v0 = true;
// ATTR: _result_v0 = ::parseAttrParamA(odsParser, odsType);
// ATTR: if (::mlir::failed(_result_v0))
// ATTR: return {};
// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
// ATTR: _seen_v1 = true;
// ATTR: {
// ATTR: auto odsCustomResult = parseNestedCustom(odsParser,
// ATTR-NEXT: ::mlir::detail::unwrapForCustomParse(_result_v1));
// ATTR: if (::mlir::failed(odsCustomResult)) return {};
// ATTR: if (::mlir::failed(_result_v1)) {
// ATTR: odsParser.emitError(odsCustomLoc, "custom parser failed to parse parameter 'v1'");
// ATTR: return {};
// ATTR: }
// ATTR: }
// ATTR: } else if (!_seen_v2 && _paramKey == "v2") {
// ATTR: _seen_v2 = true;
// ATTR: _result_v2 = ::mlir::FieldParser<AttrParamB>::parse(odsParser);
// ATTR: if (::mlir::failed(_result_v2)) {
// ATTR: odsParser.emitError(odsParser.getCurrentLocation(), "failed to parse AttrT parameter 'v2' which is to be a `AttrParamB`");
// ATTR: return {};
// ATTR: }
// ATTR: } else {
// ATTR: return {};
// ATTR: }
// ATTR: return true;
// ATTR: }
// ATTR: do {
// ATTR: ::llvm::StringRef _paramKey;
// ATTR: if (odsParser.parseKeyword(&_paramKey)) {
// ATTR: odsParser.emitError(odsParser.getCurrentLocation(),
// ATTR-NEXT: "expected a parameter name in struct");
// ATTR: return {};
// ATTR: }
// ATTR: if (!_loop_body(_paramKey)) return {};
// ATTR: } while(!odsParser.parseOptionalComma());
// ATTR: if (!_seen_v0)
// ATTR: if (!_seen_v1)
// ATTR: return TestTAttr::get(odsParser.getContext(),
// ATTR: TestParamA((*_result_v0)),
// ATTR: TestParamB((*_result_v1)),
// ATTR: AttrParamB((_result_v2.value_or(AttrParamB()))));
// ATTR: }

// ATTR: void TestTAttr::print(::mlir::AsmPrinter &odsPrinter) const {
// ATTR: odsPrinter << "v0 = ";
// ATTR: ::printAttrParamA(odsPrinter, getV0());
// ATTR: odsPrinter << ", ";
// ATTR: odsPrinter << "v1 = ";
// ATTR: printNestedCustom(odsPrinter,
// ATTR-NEXT: getV1());
// ATTR: if (!(getV2() == AttrParamB())) {
// ATTR: odsPrinter << "v2 = ";
// ATTR: odsPrinter.printStrippedAttrOrType(getV2());
// ATTR: }

def AttrT : TestAttr<"TestT"> {
let parameters = (ins
AttrParamA:$v0,
AttrParamB:$v1,
OptionalParameter<"AttrParamB">:$v2
);

let mnemonic = "attr_t";
let assemblyFormat = "`{` struct($v0, custom<NestedCustom>($v1), $v2) `}`";
}

// DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
// DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
// DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {
Expand Down
Loading
Loading