Skip to content

Commit de6d010

Browse files
authored
[mlir][tblgen] Add custom parsing and printing within struct (#133939)
This PR extends the `struct` directive in tablegen to support nested `custom` directives. Note that this assumes/verifies that that `custom` directive has a single parameter. This enables defining custom field parsing and printing functions if the `struct` directive doesn't suffice. There is some existing potential downstream usage for it: https://github.com/openxla/stablehlo/blob/a3c7de92425e8035437dae67ab2318a82eca79a1/stablehlo/dialect/StablehloOps.cpp#L3102
1 parent 3b12bac commit de6d010

File tree

10 files changed

+440
-117
lines changed

10 files changed

+440
-117
lines changed

mlir/docs/DefiningDialects/AttributesAndTypes.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -842,9 +842,9 @@ if they are not present.
842842

843843
###### `struct` Directive
844844

845-
The `struct` directive accepts a list of variables to capture and will generate
846-
a parser and printer for a comma-separated list of key-value pairs. If an
847-
optional parameter is included in the `struct`, it can be elided. The variables
845+
The `struct` directive accepts a list of variables or directives to capture and
846+
will generate a parser and printer for a comma-separated list of key-value pairs.
847+
If an optional parameter is included in the `struct`, it can be elided. The variables
848848
are printed in the order they are specified in the argument list **but can be
849849
parsed in any order**. For example:
850850

@@ -876,6 +876,13 @@ assembly format of `` `<` struct(params) `>` `` will result in:
876876
The order in which the parameters are printed is the order in which they are
877877
declared in the attribute's or type's `parameter` list.
878878

879+
Passing `custom<Foo>($variable)` allows providing a custom printer and parser
880+
for the encapsulated variable. Check the
881+
[custom and ref directive](#custom-and-ref-directive) section for more
882+
information about how to define the printer and parser functions. Note that a
883+
custom directive within a struct directive can only encapsulate a single
884+
variable.
885+
879886
###### `custom` and `ref` directive
880887

881888
The `custom` directive is used to dispatch calls to user-defined printer and
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
2+
3+
// CHECK-LABEL: @test_struct_attr_roundtrip
4+
func.func @test_struct_attr_roundtrip() -> () {
5+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
6+
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
7+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = ?, opt_value = 1>
8+
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = ?, opt_value = 1>} : () -> ()
9+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
10+
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
11+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
12+
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
13+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
14+
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
15+
return
16+
}
17+
18+
// -----
19+
20+
// Verify all required parameters must be provided. `value` is missing.
21+
22+
// expected-error @below {{struct is missing required parameter: value}}
23+
"test.op"() {attr = #test.custom_struct<type_str = "struct">} : () -> ()
24+
25+
// -----
26+
27+
// Verify all keywords must be provided. All missing.
28+
29+
// expected-error @below {{expected valid keyword}}
30+
// expected-error @below {{expected a parameter name in struct}}
31+
"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
32+
33+
// -----
34+
35+
// Verify all keywords must be provided. `type_str` missing.
36+
37+
// expected-error @below {{expected valid keyword}}
38+
// expected-error @below {{expected a parameter name in struct}}
39+
"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
40+
41+
// -----
42+
43+
// Verify all keywords must be provided. `value` missing.
44+
45+
// expected-error @below {{expected valid keyword}}
46+
// expected-error @below {{expected a parameter name in struct}}
47+
"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
48+
49+
// -----
50+
51+
// Verify invalid keyword provided.
52+
53+
// expected-error @below {{duplicate or unknown struct parameter name: type_str2}}
54+
"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
55+
56+
// -----
57+
58+
// Verify duplicated keyword provided.
59+
60+
// expected-error @below {{duplicate or unknown struct parameter name: type_str}}
61+
"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
62+
63+
// -----
64+
65+
// Verify equals missing.
66+
67+
// expected-error @below {{expected '='}}
68+
"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,16 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
369369
}];
370370
}
371371

372+
// Test `struct` with nested `custom` assembly format.
373+
def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
374+
let mnemonic = "custom_struct";
375+
let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
376+
OptionalParameter<"mlir::ArrayAttr">:$opt_value);
377+
let assemblyFormat = [{
378+
`<` struct($type_str, custom<CustomStructAttr>($value), custom<CustomOptStructFieldAttr>($opt_value)) `>`
379+
}];
380+
}
381+
372382
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
373383
let mnemonic = "nested_polynomial";
374384
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);

mlir/test/lib/Dialect/Test/TestAttributes.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,49 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
316316
return success();
317317
}
318318

319+
//===----------------------------------------------------------------------===//
320+
// TestCustomStructAttr
321+
//===----------------------------------------------------------------------===//
322+
323+
static void printCustomStructAttr(AsmPrinter &p, int64_t value) {
324+
if (ShapedType::isDynamic(value)) {
325+
p << "?";
326+
} else {
327+
p.printStrippedAttrOrType(value);
328+
}
329+
}
330+
331+
static ParseResult parseCustomStructAttr(AsmParser &p, int64_t &value) {
332+
if (succeeded(p.parseOptionalQuestion())) {
333+
value = ShapedType::kDynamic;
334+
return success();
335+
}
336+
return p.parseInteger(value);
337+
}
338+
339+
static void printCustomOptStructFieldAttr(AsmPrinter &p, ArrayAttr attr) {
340+
if (attr && attr.size() == 1 && isa<IntegerAttr>(attr[0])) {
341+
p << cast<IntegerAttr>(attr[0]).getInt();
342+
} else {
343+
p.printStrippedAttrOrType(attr);
344+
}
345+
}
346+
347+
static ParseResult parseCustomOptStructFieldAttr(AsmParser &p,
348+
ArrayAttr &attr) {
349+
int64_t value;
350+
OptionalParseResult result = p.parseOptionalInteger(value);
351+
if (result.has_value()) {
352+
if (failed(result.value()))
353+
return failure();
354+
attr = ArrayAttr::get(
355+
p.getContext(),
356+
{IntegerAttr::get(IntegerType::get(p.getContext(), 64), value)});
357+
return success();
358+
}
359+
return p.parseAttribute(attr);
360+
}
361+
319362
//===----------------------------------------------------------------------===//
320363
// TestOpAsmAttrInterfaceAttr
321364
//===----------------------------------------------------------------------===//

mlir/test/mlir-tblgen/attr-or-type-format-invalid.td

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
3737
def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
3838
let parameters = (ins "int":$v0);
3939
// CHECK: literals may only be used in the top-level section of the format
40-
// CHECK: expected a variable in `struct` argument list
40+
// CHECK: expected a parameter or `custom` directive in `struct` argument list
4141
let assemblyFormat = "`<` struct($v0, `,`) `>`";
4242
}
4343

4444
// Test struct directive cannot capture zero parameters.
4545
def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
4646
let parameters = (ins "int":$v0);
47-
// CHECK: `struct` argument list expected a variable or directive
47+
// CHECK: `struct` argument list expected a parameter or directive
4848
let assemblyFormat = "`<` struct() $v0 `>`";
4949
}
5050

@@ -144,3 +144,24 @@ def InvalidTypeT : InvalidType<"InvalidTypeT", "invalid_t"> {
144144
// CHECK: `custom` directive with no bound parameters cannot be used as optional group anchor
145145
let assemblyFormat = "$a (`(` custom<Foo>(ref($a))^ `)`)?";
146146
}
147+
148+
// Test `struct` with nested `custom` directive with multiple fields.
149+
def InvalidTypeU : InvalidType<"InvalidTypeU", "invalid_u"> {
150+
let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
151+
// CHECK: `struct` can only contain `custom` directives with a single argument
152+
let assemblyFormat = "struct(custom<Foo>($a, $b))";
153+
}
154+
155+
// Test `struct` with nested `custom` directive invalid parameter.
156+
def InvalidTypeV : InvalidType<"InvalidTypeV", "invalid_v"> {
157+
let parameters = (ins OptionalParameter<"int">:$a);
158+
// CHECK: a `custom` directive nested within a `struct` must be passed a parameter
159+
let assemblyFormat = "struct($a, custom<Foo>(ref($a)))";
160+
}
161+
162+
// Test `custom` with nested `custom` directive invalid parameter.
163+
def InvalidTypeW : InvalidType<"InvalidTypeV", "invalid_v"> {
164+
let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
165+
// CHECK: `custom` can only be used at the top-level context or within a `struct` directive
166+
let assemblyFormat = "custom<Foo>($a, custom<Bar>($b))";
167+
}

mlir/test/mlir-tblgen/attr-or-type-format.td

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,84 @@ def TypeS : TestType<"TestS"> {
736736
let assemblyFormat = "$a";
737737
}
738738

739+
/// Test that a `struct` with nested `custom` parser and printer are generated correctly.
740+
741+
// ATTR: ::mlir::Attribute TestTAttr::parse(::mlir::AsmParser &odsParser,
742+
// ATTR: ::mlir::Type odsType) {
743+
// ATTR: bool _seen_v0 = false;
744+
// ATTR: bool _seen_v1 = false;
745+
// ATTR: bool _seen_v2 = false;
746+
// ATTR: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
747+
// ATTR: if (odsParser.parseEqual())
748+
// ATTR: return {};
749+
// ATTR: if (!_seen_v0 && _paramKey == "v0") {
750+
// ATTR: _seen_v0 = true;
751+
// ATTR: _result_v0 = ::parseAttrParamA(odsParser, odsType);
752+
// ATTR: if (::mlir::failed(_result_v0))
753+
// ATTR: return {};
754+
// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
755+
// ATTR: _seen_v1 = true;
756+
// ATTR: {
757+
// ATTR: auto odsCustomResult = parseNestedCustom(odsParser,
758+
// ATTR-NEXT: ::mlir::detail::unwrapForCustomParse(_result_v1));
759+
// ATTR: if (::mlir::failed(odsCustomResult)) return {};
760+
// ATTR: if (::mlir::failed(_result_v1)) {
761+
// ATTR: odsParser.emitError(odsCustomLoc, "custom parser failed to parse parameter 'v1'");
762+
// ATTR: return {};
763+
// ATTR: }
764+
// ATTR: }
765+
// ATTR: } else if (!_seen_v2 && _paramKey == "v2") {
766+
// ATTR: _seen_v2 = true;
767+
// ATTR: _result_v2 = ::mlir::FieldParser<AttrParamB>::parse(odsParser);
768+
// ATTR: if (::mlir::failed(_result_v2)) {
769+
// ATTR: odsParser.emitError(odsParser.getCurrentLocation(), "failed to parse AttrT parameter 'v2' which is to be a `AttrParamB`");
770+
// ATTR: return {};
771+
// ATTR: }
772+
// ATTR: } else {
773+
// ATTR: return {};
774+
// ATTR: }
775+
// ATTR: return true;
776+
// ATTR: }
777+
// ATTR: do {
778+
// ATTR: ::llvm::StringRef _paramKey;
779+
// ATTR: if (odsParser.parseKeyword(&_paramKey)) {
780+
// ATTR: odsParser.emitError(odsParser.getCurrentLocation(),
781+
// ATTR-NEXT: "expected a parameter name in struct");
782+
// ATTR: return {};
783+
// ATTR: }
784+
// ATTR: if (!_loop_body(_paramKey)) return {};
785+
// ATTR: } while(!odsParser.parseOptionalComma());
786+
// ATTR: if (!_seen_v0)
787+
// ATTR: if (!_seen_v1)
788+
// ATTR: return TestTAttr::get(odsParser.getContext(),
789+
// ATTR: TestParamA((*_result_v0)),
790+
// ATTR: TestParamB((*_result_v1)),
791+
// ATTR: AttrParamB((_result_v2.value_or(AttrParamB()))));
792+
// ATTR: }
793+
794+
// ATTR: void TestTAttr::print(::mlir::AsmPrinter &odsPrinter) const {
795+
// ATTR: odsPrinter << "v0 = ";
796+
// ATTR: ::printAttrParamA(odsPrinter, getV0());
797+
// ATTR: odsPrinter << ", ";
798+
// ATTR: odsPrinter << "v1 = ";
799+
// ATTR: printNestedCustom(odsPrinter,
800+
// ATTR-NEXT: getV1());
801+
// ATTR: if (!(getV2() == AttrParamB())) {
802+
// ATTR: odsPrinter << "v2 = ";
803+
// ATTR: odsPrinter.printStrippedAttrOrType(getV2());
804+
// ATTR: }
805+
806+
def AttrT : TestAttr<"TestT"> {
807+
let parameters = (ins
808+
AttrParamA:$v0,
809+
AttrParamB:$v1,
810+
OptionalParameter<"AttrParamB">:$v2
811+
);
812+
813+
let mnemonic = "attr_t";
814+
let assemblyFormat = "`{` struct($v0, custom<NestedCustom>($v1), $v2) `}`";
815+
}
816+
739817
// DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
740818
// DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
741819
// DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {

0 commit comments

Comments
 (0)