Skip to content

Commit 8e20e50

Browse files
committed
[mlir] Add struct parsing and printing utilities
1 parent 41b83b4 commit 8e20e50

File tree

6 files changed

+218
-0
lines changed

6 files changed

+218
-0
lines changed

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,29 @@ class AsmPrinter {
238238

239239
void printDimensionList(ArrayRef<int64_t> shape);
240240

241+
//===----------------------------------------------------------------------===//
242+
// Struct Printing
243+
//===----------------------------------------------------------------------===//
244+
245+
/// Print a comma-separated list of key-value pairs using the provided
246+
/// `keywords` and corresponding printing functions. This performs similar
247+
/// printing as the the assembly format's `struct` directive printer, but
248+
/// allows bringing in custom printers for fields.
249+
///
250+
/// Example:
251+
/// <
252+
/// foo = foo_value,
253+
/// bar = bar_value,
254+
/// ...
255+
/// >
256+
virtual void
257+
printStruct(ArrayRef<StringRef> keywords,
258+
ArrayRef<llvm::function_ref<void(AsmPrinter &p)>> printFuncs);
259+
260+
//===----------------------------------------------------------------------===//
261+
// Cyclic Printing
262+
//===----------------------------------------------------------------------===//
263+
241264
/// Class used to automatically end a cyclic region on destruction.
242265
class CyclicPrintReset {
243266
public:
@@ -1409,6 +1432,26 @@ class AsmParser {
14091432
return CyclicParseReset(this);
14101433
}
14111434

1435+
//===----------------------------------------------------------------------===//
1436+
// Struct Parsing
1437+
//===----------------------------------------------------------------------===//
1438+
1439+
/// Parse a comma-separated list of key-value pairs with a specified
1440+
/// delimiter. This performs similar parsing as the the assembly format
1441+
/// `struct` directive parser with custom delimiter and/or field parsing. The
1442+
/// variables are printed in the order they are specified in the argument list
1443+
/// but can be parsed in any order.
1444+
///
1445+
/// Example:
1446+
/// <
1447+
/// foo = something_parsed_by_a_custom_parser,
1448+
/// bar = something_parsed_by_a_different_custom_parser,
1449+
/// ...
1450+
/// >
1451+
virtual ParseResult
1452+
parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
1453+
ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) = 0;
1454+
14121455
protected:
14131456
/// Parse a handle to a resource within the assembly format for the given
14141457
/// dialect.

mlir/lib/AsmParser/AsmParserImpl.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,51 @@ class AsmParserImpl : public BaseT {
570570
parser.getState().cyclicParsingStack.pop_back();
571571
}
572572

573+
//===----------------------------------------------------------------------===//
574+
// Struct Parsing
575+
//===----------------------------------------------------------------------===//
576+
577+
/// Parse a comma-separated list of key-value pairs with a specified
578+
/// delimiter.
579+
ParseResult
580+
parseStruct(Delimiter delimiter, ArrayRef<StringRef> keywords,
581+
ArrayRef<llvm::function_ref<ParseResult()>> parseFuncs) override {
582+
assert(keywords.size() == parseFuncs.size());
583+
auto keyError = [&]() -> ParseResult {
584+
InFlightDiagnostic parseError =
585+
emitError(getCurrentLocation(), "expected one of: ");
586+
llvm::interleaveComma(keywords, parseError, [&](StringRef kw) {
587+
parseError << '`' << kw << '`';
588+
});
589+
return parseError;
590+
};
591+
SmallVector<bool> seen(keywords.size(), false);
592+
DenseMap<StringRef, size_t> keywordToIndex;
593+
for (auto &&[idx, keyword] : llvm::enumerate(keywords))
594+
keywordToIndex[keyword] = idx;
595+
return parseCommaSeparatedList(
596+
delimiter,
597+
[&]() -> ParseResult {
598+
StringRef keyword;
599+
if (failed(parseOptionalKeyword(&keyword)))
600+
return keyError();
601+
if (!keywordToIndex.contains(keyword))
602+
return keyError();
603+
size_t idx = keywordToIndex[keyword];
604+
if (seen[idx]) {
605+
return emitError(getCurrentLocation(), "duplicated `")
606+
<< keyword << "` entry";
607+
}
608+
if (failed(parseEqual()))
609+
return failure();
610+
if (failed(parseFuncs[idx]()))
611+
return failure();
612+
seen[idx] = true;
613+
return success();
614+
},
615+
"parse struct");
616+
}
617+
573618
//===--------------------------------------------------------------------===//
574619
// Code Completion
575620
//===--------------------------------------------------------------------===//

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3002,6 +3002,28 @@ void AsmPrinter::printDimensionList(ArrayRef<int64_t> shape) {
30023002
detail::printDimensionList(getStream(), shape);
30033003
}
30043004

3005+
//===----------------------------------------------------------------------===//
3006+
// Struct Printing
3007+
//===----------------------------------------------------------------------===//
3008+
3009+
/// Print a comma-separated list of key-value pairs.
3010+
void AsmPrinter::printStruct(
3011+
ArrayRef<StringRef> keywords,
3012+
ArrayRef<llvm::function_ref<void(AsmPrinter &p)>> printFuncs) {
3013+
DenseMap<StringRef, llvm::function_ref<void(AsmPrinter & p)>> keywordToFunc;
3014+
for (auto &&[kw, printFunc] : llvm::zip(keywords, printFuncs))
3015+
keywordToFunc[kw] = printFunc;
3016+
auto &os = getStream();
3017+
llvm::interleaveComma(keywords, os, [&](StringRef kw) {
3018+
os << kw << " = ";
3019+
keywordToFunc[kw](*this);
3020+
});
3021+
}
3022+
3023+
//===----------------------------------------------------------------------===//
3024+
// Cyclic Printing
3025+
//===----------------------------------------------------------------------===//
3026+
30053027
LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
30063028
return impl->pushCyclicPrinting(opaquePointer);
30073029
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
2+
3+
// CHECK-LABEL: @test_struct_attr_roundtrip
4+
func.func @test_struct_attr_roundtrip() -> () {
5+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
6+
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>} : () -> ()
7+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2, opt_value = [3, 3]>
8+
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct", opt_value = [3, 3]>} : () -> ()
9+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
10+
"test.op"() {attr = #test.custom_struct<type_str = "struct", value = 2>} : () -> ()
11+
// CHECK: attr = #test.custom_struct<type_str = "struct", value = 2>
12+
"test.op"() {attr = #test.custom_struct<value = 2, type_str = "struct">} : () -> ()
13+
return
14+
}
15+
16+
// -----
17+
18+
// Verify all keywords must be provided. All missing.
19+
20+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
21+
// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
22+
"test.op"() {attr = #test.custom_struct<"struct", 2>} : () -> ()
23+
24+
// -----
25+
26+
// Verify all keywords must be provided. `type_str` missing.
27+
28+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
29+
// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
30+
"test.op"() {attr = #test.custom_struct<"struct", value = 2, opt_value = [3, 3]>} : () -> ()
31+
32+
// -----
33+
34+
// Verify all keywords must be provided. `value` missing.
35+
36+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
37+
// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
38+
"test.op"() {attr = #test.custom_struct<type_str = "struct", 2>} : () -> ()
39+
40+
// -----
41+
42+
// Verify invalid keyword provided.
43+
44+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
45+
// expected-error @below {{expected one of: `type_str`, `value`, `opt_value`}}
46+
"test.op"() {attr = #test.custom_struct<type_str2 = "struct", value = 2>} : () -> ()
47+
48+
// -----
49+
50+
// Verify duplicated keyword provided.
51+
52+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
53+
// expected-error @below {{duplicated `type_str` entry}}
54+
"test.op"() {attr = #test.custom_struct<type_str = "struct", type_str = "struct2", value = 2>} : () -> ()
55+
56+
// -----
57+
58+
// Verify equals missing.
59+
60+
// expected-error @below {{failed parsing `TestCustomStructAttr`}}
61+
// expected-error @below {{expected '='}}
62+
"test.op"() {attr = #test.custom_struct<type_str "struct", value = 2>} : () -> ()

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,15 @@ def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
369369
}];
370370
}
371371

372+
// Test AsmParser::parseStruct and AsmPrinter::printStruct APIs through the custom
373+
// parser and printer.
374+
def TestCustomStructAttr : Test_Attr<"TestCustomStruct"> {
375+
let mnemonic = "custom_struct";
376+
let parameters = (ins "mlir::StringAttr":$type_str, "int64_t":$value,
377+
OptionalParameter<"mlir::ArrayAttr">:$opt_value);
378+
let hasCustomAssemblyFormat = 1;
379+
}
380+
372381
def NestedPolynomialAttr : Test_Attr<"NestedPolynomialAttr"> {
373382
let mnemonic = "nested_polynomial";
374383
let parameters = (ins Polynomial_IntPolynomialAttr:$poly);

mlir/test/lib/Dialect/Test/TestAttributes.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,43 @@ static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
316316
return success();
317317
}
318318

319+
//===----------------------------------------------------------------------===//
320+
// TestCustomStructAttr
321+
//===----------------------------------------------------------------------===//
322+
323+
Attribute TestCustomStructAttr::parse(AsmParser &p, Type type) {
324+
std::string typeStr;
325+
int64_t value;
326+
FailureOr<ArrayAttr> optValue;
327+
if (failed(p.parseStruct(AsmParser::Delimiter::LessGreater,
328+
{"type_str", "value", "opt_value"},
329+
{[&]() { return p.parseString(&typeStr); },
330+
[&]() { return p.parseInteger(value); },
331+
[&]() {
332+
optValue = mlir::FieldParser<ArrayAttr>::parse(p);
333+
return success(succeeded(optValue));
334+
}}))) {
335+
p.emitError(p.getCurrentLocation())
336+
<< "failed parsing `TestCustomStructAttr`";
337+
return {};
338+
}
339+
return get(p.getContext(), StringAttr::get(p.getContext(), typeStr), value,
340+
optValue.value_or(ArrayAttr()));
341+
}
342+
343+
void TestCustomStructAttr::print(AsmPrinter &p) const {
344+
p << "<";
345+
p.printStruct(
346+
{"type_str", "value"},
347+
{[&](AsmPrinter &p) { p.printStrippedAttrOrType(getTypeStr()); },
348+
[&](AsmPrinter &p) { p.printStrippedAttrOrType(getValue()); }});
349+
if (getOptValue() != ArrayAttr()) {
350+
p << ", opt_value = ";
351+
p.printStrippedAttrOrType(getOptValue());
352+
}
353+
p << ">";
354+
}
355+
319356
//===----------------------------------------------------------------------===//
320357
// TestOpAsmAttrInterfaceAttr
321358
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)