Skip to content

Commit 3d95ba3

Browse files
committed
[mlir] Add struct parsing and printing utilities
1 parent 41b83b4 commit 3d95ba3

File tree

5 files changed

+199
-0
lines changed

5 files changed

+199
-0
lines changed

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,34 @@ class AsmPrinter {
177177
*this << attrOrType;
178178
}
179179

180+
//===----------------------------------------------------------------------===//
181+
// Struct Printing
182+
//===----------------------------------------------------------------------===//
183+
184+
template <typename T>
185+
void printStruct(std::pair<const char *, T> field) {
186+
auto &os = getStream();
187+
os << std::get<0>(field) << " = ";
188+
printStrippedAttrOrType(std::get<1>(field));
189+
}
190+
191+
/// Print a comma-separated list of key-value pairs. This performs similar
192+
/// parsing as the the assembly format's `struct` directive printer.
193+
///
194+
/// Example:
195+
/// <
196+
/// foo = foo_value,
197+
/// bar = bar_value,
198+
/// ...
199+
/// >
200+
template <typename T, typename... Ts>
201+
void printStruct(std::pair<const char *, T> field, Ts... rest) {
202+
auto &os = getStream();
203+
printStruct(field);
204+
os << ", ";
205+
printStruct(rest...);
206+
}
207+
180208
/// Print the given attribute without its type. The corresponding parser must
181209
/// provide a valid type for the attribute.
182210
virtual void printAttributeWithoutType(Attribute attr);
@@ -1409,6 +1437,26 @@ class AsmParser {
14091437
return CyclicParseReset(this);
14101438
}
14111439

1440+
//===----------------------------------------------------------------------===//
1441+
// Struct Parsing
1442+
//===----------------------------------------------------------------------===//
1443+
1444+
/// Parse a comma-separated list of key-value pairs with a specified
1445+
/// delimiter. This performs similar parsing as the the assembly format
1446+
/// `struct` directive parser with a specified delimiter. The variables are
1447+
/// printed in the order they are specified in the argument list but can be
1448+
/// parsed in any order.
1449+
///
1450+
/// Example:
1451+
/// <
1452+
/// foo = something_parsed_by_a_custom_parser,
1453+
/// bar = something_parsed_by_a_different_custom_parser,
1454+
/// ...
1455+
/// >
1456+
virtual ParseResult
1457+
parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
1458+
ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) = 0;
1459+
14121460
protected:
14131461
/// Parse a handle to a resource within the assembly format for the given
14141462
/// dialect.

mlir/lib/AsmParser/AsmParserImpl.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,51 @@ class AsmParserImpl : public BaseT {
570570
parser.getState().cyclicParsingStack.pop_back();
571571
}
572572

573+
//===----------------------------------------------------------------------===//
574+
// Struct Parsing
575+
//===----------------------------------------------------------------------===//
576+
577+
/// Parse a comma-separated list of key-value pairs with a specified
578+
/// delimiter.
579+
ParseResult
580+
parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
581+
ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) override {
582+
assert(keywords.size() == parseFuncs.size());
583+
auto keyError = [&]() -> ParseResult {
584+
InFlightDiagnostic parseError =
585+
emitError(getCurrentLocation(), "expected one of: ");
586+
llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
587+
parseError << '`' << kw << '`';
588+
});
589+
return parseError;
590+
};
591+
SmallVector<bool> seen(keywords.size(), false);
592+
DenseMap<StringRef, size_t> keywordToIndex;
593+
for (auto &&[idx, keyword] : llvm::enumerate(keywords))
594+
keywordToIndex[keyword] = idx;
595+
return parseCommaSeparatedList(
596+
delimiter,
597+
[&]() -> ParseResult {
598+
StringRef keyword;
599+
if (failed(parseOptionalKeyword(&keyword)))
600+
return keyError();
601+
if (!keywordToIndex.contains(keyword))
602+
return keyError();
603+
size_t idx = keywordToIndex[keyword];
604+
if (seen[idx]) {
605+
return emitError(getCurrentLocation(), "duplicated `")
606+
<< keyword << "` entry";
607+
}
608+
if (failed(parseEqual()))
609+
return failure();
610+
if (failed(parseFuncs[idx]()))
611+
return failure();
612+
seen[idx] = true;
613+
return success();
614+
},
615+
"parse struct");
616+
}
617+
573618
//===--------------------------------------------------------------------===//
574619
// Code Completion
575620
//===--------------------------------------------------------------------===//
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
2+
3+
// CHECK-LABEL: @test_struct_attr_roundtrip
4+
func.func @test_struct_attr_roundtrip() -> () {
5+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_param = [3, 3]>
6+
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_param = [3, 3]>} : () -> ()
7+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_param = [3, 3]>
8+
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_param = [3, 3]>} : () -> ()
9+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
10+
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
11+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
12+
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
13+
return
14+
}
15+
16+
// -----
17+
18+
// Verify all keywords must be provided. All missing.
19+
20+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
21+
// expected-error @below {{expected one of: `type_str`, `value`, `opt_param`}}
22+
"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
23+
24+
// -----
25+
26+
// Verify all keywords must be provided. `type_str` missing.
27+
28+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
29+
// expected-error @below {{expected one of: `type_str`, `value`, `opt_param`}}
30+
"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_param = [3, 3]>} : () -> ()
31+
32+
// -----
33+
34+
// Verify all keywords must be provided. `value` missing.
35+
36+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
37+
// expected-error @below {{expected one of: `type_str`, `value`, `opt_param`}}
38+
"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
39+
40+
// -----
41+
42+
// Verify invalid keyword provided.
43+
44+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
45+
// expected-error @below {{expected one of: `type_str`, `value`, `opt_param`}}
46+
"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
47+
48+
// -----
49+
50+
// Verify duplicated keyword provided.
51+
52+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
53+
// expected-error @below {{duplicated `type_str` entry}}
54+
"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
55+
56+
// -----
57+
58+
// Verify equals missing.
59+
60+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
61+
// expected-error @below {{expected '='}}
62+
"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,15 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
369369
}];
370370
}
371371

372+
// Test AsmParser::parseStruct and AsmPrinter::printStruct APIs through the custom
373+
// parser and printer.
374+
def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
375+
let mnemonic = "custom_struct";
376+
let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
377+
OptionalParameter<"mlir::ArrayAttr">:$opt_param);
378+
let hasCustomAssemblyFormat = 1;
379+
}
380+
372381
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
373382
let mnemonic = "nested_polynomial";
374383
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);

mlir/test/lib/Dialect/Test/TestAttributes.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,41 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
316316
return success();
317317
}
318318

319+
//===----------------------------------------------------------------------===//
320+
// TestCustomStructAttr
321+
//===----------------------------------------------------------------------===//
322+
323+
Attribute TestCustomStructAttr::parse(AsmParser &p, Type type) {
324+
std::string typeStr;
325+
int64_t value;
326+
FailureOr<ArrayAttr> optParam;
327+
if (failed(p.parseStruct(AsmParser::Delimiter::LessGreater,
328+
{"type_str", "value", "opt_param"},
329+
{[&]() { return p.parseString(&typeStr); },
330+
[&]() { return p.parseInteger(value); },
331+
[&]() {
332+
optParam = mlir::FieldParser<ArrayAttr>::parse(p);
333+
return success(succeeded(optParam));
334+
}}))) {
335+
p.emitError(p.getCurrentLocation())
336+
<< "failed parsing `TestCustomStructAttr`";
337+
return {};
338+
}
339+
return get(p.getContext(), StringAttr::get(p.getContext(), typeStr), value,
340+
optParam.value_or(ArrayAttr()));
341+
}
342+
343+
void TestCustomStructAttr::print(AsmPrinter &p) const {
344+
p << "<";
345+
p.printStruct(std::make_pair("type_str", getTypeStr()),
346+
std::make_pair("value", getValue()));
347+
if (getOptParam() != ArrayAttr()) {
348+
p << ", ";
349+
p.printStruct(std::make_pair("opt_param", getOptParam()));
350+
}
351+
p << ">";
352+
}
353+
319354
//===----------------------------------------------------------------------===//
320355
// TestOpAsmAttrInterfaceAttr
321356
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)