Skip to content

Commit 0acc260

Browse files
committed
[mlir][linalg] Support generating builders for named op attributes
This commit adds support to generate an additional builder for each named op that has attributes. This gives better experience when creating the named ops. Along the way adds support for i64. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D94733
1 parent 5e4480b commit 0acc260

File tree

2 files changed

+70
-6
lines changed

2 files changed

+70
-6
lines changed

mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
8888
// ODS: F32:$f32_attr,
8989
// ODS: RankedF32ElementsAttr<[4]>:$fvec_attr,
9090
// ODS: I32:$i32_attr,
91+
// ODS: I64:$i64_attr,
9192
// ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr,
9293
// ODS: OptionalAttr<F32>:$optional_attr
9394
//
@@ -96,6 +97,7 @@ def test4(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N))
9697
attr(
9798
f32_attr: f32,
9899
i32_attr: i32,
100+
i64_attr: i64,
99101
fvec_attr: 4xf32,
100102
ivec_attr: 5x6xi32,
101103
array_attr : f32[],
@@ -126,6 +128,7 @@ def test5(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F))
126128
I(n, h * strides[0] + kh, w * strides[1] + kw, c), K(f, kh, kw, c)));
127129
}
128130

131+
// Test documentation
129132
// ODS-LABEL: def Test6Op
130133
// ODS: let summary = [{ My magic op. }];
131134
// ODS-NEXT: let description = [{
@@ -144,3 +147,18 @@ It has one output.
144147
{
145148
C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
146149
}
150+
151+
// Test attribute builder
152+
// ODS-LABEL: def Test7Op
153+
// ODS: OpBuilderDAG<
154+
// ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
155+
// ODS: "ValueRange":$outputs, "Attribute":$attr_a, "Attribute":$attr_b)
156+
// ODS: $_state.addAttribute("attr_a", attr_a);
157+
// ODS: $_state.addAttribute("attr_b", attr_b);
158+
//
159+
ods_def<Test7Op>:
160+
def test7(A: f32(M, K), B: f32(K)) -> (C: f32(M))
161+
attr(attr_a: f32, attr_b: 4xi32)
162+
{
163+
C(m) = std_addf<k>(std_mulf(A(m, k), B(k)));
164+
}

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,6 +1768,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
17681768
std::string odsType = llvm::StringSwitch<std::string>(elementType)
17691769
.Case("f32", "F32")
17701770
.Case("i32", "I32")
1771+
.Case("i64", "I64")
17711772
.Default("");
17721773
if (odsType.empty()) {
17731774
parser.emitError("unimplemented support for attribute element type: " +
@@ -1811,7 +1812,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
18111812
let regions = (region AnyRegion:$region);
18121813
18131814
let skipDefaultBuilders = 1;
1814-
let builders = [ OpBuilderDAG<
1815+
let builders = [
1816+
OpBuilderDAG<
18151817
(ins "ValueRange":$inputs, "ValueRange":$outputs),
18161818
[{{
18171819
$_state.addOperands(inputs);
@@ -1826,7 +1828,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
18261828
$_state,
18271829
TypeRange(inputs),
18281830
TypeRange(outputs));
1829-
}]>, OpBuilderDAG<
1831+
}]>,
1832+
OpBuilderDAG<
18301833
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
18311834
"ValueRange":$outputs),
18321835
[{{
@@ -1843,7 +1846,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
18431846
$_state,
18441847
TypeRange(inputs),
18451848
TypeRange(outputs));
1846-
}]>, OpBuilderDAG<
1849+
}]>,
1850+
OpBuilderDAG<
18471851
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
18481852
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
18491853
[{{
@@ -1852,6 +1856,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
18521856
$_state.addTypes(resultTensorTypes);
18531857
(void)$_state.addRegion();
18541858
}]>
1859+
{5}
18551860
];
18561861
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
18571862
let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
@@ -1873,8 +1878,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
18731878
}];
18741879
})FMT";
18751880

1881+
// Generate documentation.
18761882
std::string doc;
1877-
18781883
if (!docString.empty()) {
18791884
const char *docFmt = R"FMT(
18801885
let summary = [{ {0} }];
@@ -1888,8 +1893,47 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
18881893
doc = llvm::formatv(docFmt, summary.trim(), description.trim());
18891894
}
18901895

1896+
// Generate an additional builder that has parameters for attributes.
1897+
std::string attrBuilder;
1898+
if (!registeredAttrs.empty()) {
1899+
SmallVector<std::string, 4> attrParams, attrStmts;
1900+
for (const auto &attr : registeredAttrs) {
1901+
llvm::StringRef name = attr.first;
1902+
attrParams.push_back(llvm::formatv("\"Attribute\":${0}", name));
1903+
attrStmts.push_back(
1904+
llvm::formatv("$_state.addAttribute(\"{0}\", {0});", name));
1905+
}
1906+
std::string attrParamsList = llvm::join(attrParams, ", ");
1907+
std::string attrStmtsList = llvm::join(attrStmts, "\n");
1908+
1909+
const char *builderFmt = R"FMT(
1910+
, OpBuilderDAG<
1911+
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1912+
"ValueRange":$outputs, {1}),
1913+
[{{
1914+
$_state.addOperands(inputs);
1915+
$_state.addOperands(outputs);
1916+
$_state.addTypes(resultTensorTypes);
1917+
$_state.addAttribute(
1918+
"operand_segment_sizes",
1919+
$_builder.getI32VectorAttr({{
1920+
static_cast<int32_t>(inputs.size()),
1921+
static_cast<int32_t>(outputs.size())}));
1922+
buildNamedStructuredOpRegionAndAttributes<{0}>(
1923+
$_builder,
1924+
$_state,
1925+
TypeRange(inputs),
1926+
TypeRange(outputs));
1927+
{2}
1928+
}]>
1929+
)FMT";
1930+
attrBuilder =
1931+
llvm::formatv(builderFmt, cppOpName, attrParamsList, attrStmtsList);
1932+
}
1933+
1934+
// Finally put everything together.
18911935
os << llvm::formatv(header, cppOpName, linalgOpName, doc, attrList,
1892-
state.orderedTensorArgs.size());
1936+
state.orderedTensorArgs.size(), attrBuilder);
18931937
}
18941938

18951939
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
@@ -2146,13 +2190,15 @@ TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
21462190
return llvm::formatv("getValue<float>({ {0} })", indexList);
21472191
if (elementType == "i32")
21482192
return llvm::formatv("getValue<int>({ {0} })", indexList);
2193+
if (elementType == "i64")
2194+
return llvm::formatv("getValue<int64_t>({ {0} })", indexList);
21492195

21502196
return "";
21512197
}
21522198

21532199
if (elementType == "f32")
21542200
return "getValue().convertToFloat()";
2155-
if (elementType == "i32")
2201+
if (elementType == "i32" || elementType == "i64")
21562202
return "getInt()";
21572203
return "";
21582204
}

0 commit comments

Comments
 (0)