Skip to content

Commit faa7628

Browse files
committed
[mlir][tblgen] Add custom parsing and printing within struct
1 parent 41b83b4 commit faa7628

File tree

9 files changed

+453
-158
lines changed

9 files changed

+453
-158
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
2+
3+
// CHECK-LABEL: @test_struct_attr_roundtrip
4+
func.func @test_struct_attr_roundtrip() -> () {
5+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
6+
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
7+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
8+
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
9+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
10+
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
11+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
12+
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
13+
return
14+
}
15+
16+
// -----
17+
18+
// Verify all required parameters must be provided. `value` is missing.
19+
20+
// expected-error @below {{struct is missing required parameter: value}}
21+
"test.op"() {attr = #test.custom_struct<type_str = "struct">} : () -> ()
22+
23+
// -----
24+
25+
// Verify all keywords must be provided. All missing.
26+
27+
// expected-error @below {{expected valid keyword}}
28+
// expected-error @below {{expected a parameter name in struct}}
29+
"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
30+
31+
// -----
32+
33+
// Verify all keywords must be provided. `type_str` missing.
34+
35+
// expected-error @below {{expected valid keyword}}
36+
// expected-error @below {{expected a parameter name in struct}}
37+
"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
38+
39+
// -----
40+
41+
// Verify all keywords must be provided. `value` missing.
42+
43+
// expected-error @below {{expected valid keyword}}
44+
// expected-error @below {{expected a parameter name in struct}}
45+
"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
46+
47+
// -----
48+
49+
// Verify invalid keyword provided.
50+
51+
// expected-error @below {{duplicate or unknown struct parameter name: type_str2}}
52+
"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
53+
54+
// -----
55+
56+
// Verify duplicated keyword provided.
57+
58+
// expected-error @below {{duplicate or unknown struct parameter name: type_str}}
59+
"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
60+
61+
// -----
62+
63+
// Verify equals missing.
64+
65+
// expected-error @below {{expected '='}}
66+
"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,16 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
369369
}];
370370
}
371371

372+
// Test `struct` with nested `custom` assembly format.
373+
def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
374+
let mnemonic = "custom_struct";
375+
let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
376+
OptionalParameter<"mlir::ArrayAttr">:$opt_value);
377+
let assemblyFormat = [{
378+
`<` struct($type_str, custom<CustomStructAttr>($value), custom<CustomOptStructFieldAttr>($opt_value)) `>`
379+
}];
380+
}
381+
372382
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
373383
let mnemonic = "nested_polynomial";
374384
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);

mlir/test/lib/Dialect/Test/TestAttributes.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,33 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
316316
return success();
317317
}
318318

319+
//===----------------------------------------------------------------------===//
320+
// TestCustomStructAttr
321+
//===----------------------------------------------------------------------===//
322+
323+
static void printCustomStructAttr(AsmPrinter &p, int64_t value) {
324+
p.printStrippedAttrOrType(value);
325+
}
326+
327+
static ParseResult parseCustomStructAttr(AsmParser &p, int64_t &value) {
328+
if (failed(p.parseInteger(value))) {
329+
return failure();
330+
}
331+
return success();
332+
}
333+
334+
static void printCustomOptStructFieldAttr(AsmPrinter &p, ArrayAttr attr) {
335+
p.printStrippedAttrOrType(attr);
336+
}
337+
338+
static ParseResult parseCustomOptStructFieldAttr(AsmParser &p,
339+
ArrayAttr &attr) {
340+
if (failed(p.parseAttribute(attr))) {
341+
return failure();
342+
}
343+
return success();
344+
}
345+
319346
//===----------------------------------------------------------------------===//
320347
// TestOpAsmAttrInterfaceAttr
321348
//===----------------------------------------------------------------------===//

mlir/test/mlir-tblgen/attr-or-type-format-invalid.td

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
3737
def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
3838
let parameters = (ins "int":$v0);
3939
// CHECK: literals may only be used in the top-level section of the format
40-
// CHECK: expected a variable in `struct` argument list
40+
// CHECK: expected a variable or `custom` directive in `struct` argument list
4141
let assemblyFormat = "`<` struct($v0, `,`) `>`";
4242
}
4343

@@ -144,3 +144,17 @@ def InvalidTypeT : InvalidType<"InvalidTypeT", "invalid_t"> {
144144
// CHECK: `custom` directive with no bound parameters cannot be used as optional group anchor
145145
let assemblyFormat = "$a (`(` custom<Foo>(ref($a))^ `)`)?";
146146
}
147+
148+
// Test `struct` with nested `custom` directive with multiple fields.
149+
def InvalidTypeU : InvalidType<"InvalidTypeU", "invalid_u"> {
150+
let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
151+
// CHECK: `struct` can only contain `custom` directives with a single argument
152+
let assemblyFormat = "struct(custom<Foo>($a, $b))";
153+
}
154+
155+
// Test `struct` with nested `custom` directive invalid parameter.
156+
def InvalidTypeV : InvalidType<"InvalidTypeV", "invalid_v"> {
157+
let parameters = (ins OptionalParameter<"int">:$a);
158+
// CHECK: a `custom` directive nested within a `struct` must be passed a parameter
159+
let assemblyFormat = "struct($a, custom<Foo>(ref($a)))";
160+
}

mlir/test/mlir-tblgen/attr-or-type-format.td

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,78 @@ def TypeS : TestType<"TestS"> {
736736
let assemblyFormat = "$a";
737737
}
738738

739+
/// Test that a `struct` with nested `custom` parser and printer are generated correctly.
740+
741+
// ATTR: ::mlir::Attribute TestTAttr::parse(::mlir::AsmParser &odsParser,
742+
// ATTR: ::mlir::Type odsType) {
743+
// ATTR: bool _seen_v0 = false;
744+
// ATTR: bool _seen_v1 = false;
745+
// ATTR: bool _seen_v2 = false;
746+
// ATTR: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
747+
// ATTR: if (odsParser.parseEqual())
748+
// ATTR: return {};
749+
// ATTR: if (!_seen_v0 && _paramKey == "v0") {
750+
// ATTR: _seen_v0 = true;
751+
// ATTR: _result_v0 = ::parseAttrParamA(odsParser, odsType);
752+
// ATTR: if (::mlir::failed(_result_v0))
753+
// ATTR: return {};
754+
// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
755+
// ATTR: _seen_v1 = true;
756+
// ATTR: {
757+
// ATTR: auto odsCustomResult = parseNestedCustom(odsParser,
758+
// ATTR-NEXT: ::mlir::detail::unwrapForCustomParse(_result_v1));
759+
// ATTR: if (::mlir::failed(odsCustomResult)) return {};
760+
// ATTR: if (::mlir::failed(_result_v1)) {
761+
// ATTR: odsParser.emitError(odsCustomLoc, "custom parser failed to parse parameter 'v1'");
762+
// ATTR: return {};
763+
// ATTR: }
764+
// ATTR: }
765+
// ATTR: } else if (!_seen_v2 && _paramKey == "v2") {
766+
// ATTR: } else {
767+
// ATTR: return {};
768+
// ATTR: }
769+
// ATTR: return true;
770+
// ATTR: }
771+
// ATTR: do {
772+
// ATTR: ::llvm::StringRef _paramKey;
773+
// ATTR: if (odsParser.parseKeyword(&_paramKey)) {
774+
// ATTR: odsParser.emitError(odsParser.getCurrentLocation(),
775+
// ATTR-NEXT: "expected a parameter name in struct");
776+
// ATTR: return {};
777+
// ATTR: }
778+
// ATTR: if (!_loop_body(_paramKey)) return {};
779+
// ATTR: } while(!odsParser.parseOptionalComma());
780+
// ATTR: if (!_seen_v0)
781+
// ATTR: if (!_seen_v1)
782+
// ATTR: return TestTAttr::get(odsParser.getContext(),
783+
// ATTR: TestParamA((*_result_v0)),
784+
// ATTR: TestParamB((*_result_v1)),
785+
// ATTR: AttrParamB((_result_v2.value_or(AttrParamB()))));
786+
// ATTR: }
787+
788+
// ATTR: void TestTAttr::print(::mlir::AsmPrinter &odsPrinter) const {
789+
// ATTR: odsPrinter << "v0 = ";
790+
// ATTR: ::printAttrParamA(odsPrinter, getV0());
791+
// ATTR: odsPrinter << ", ";
792+
// ATTR: odsPrinter << "v1 = ";
793+
// ATTR: printNestedCustom(odsPrinter,
794+
// ATTR-NEXT: getV1());
795+
// ATTR: if (!(getV2() == AttrParamB())) {
796+
// ATTR: odsPrinter << "v2 = ";
797+
// ATTR: odsPrinter.printStrippedAttrOrType(getV2());
798+
// ATTR: }
799+
800+
def AttrT : TestAttr<"TestT"> {
801+
let parameters = (ins
802+
AttrParamA:$v0,
803+
AttrParamB:$v1,
804+
OptionalParameter<"AttrParamB">:$v2
805+
);
806+
807+
let mnemonic = "attr_t";
808+
let assemblyFormat = "`{` struct($v0, custom<NestedCustom>($v1), $v2) `}`";
809+
}
810+
739811
// DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
740812
// DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
741813
// DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {

0 commit comments

Comments
 (0)