@@ -1768,6 +1768,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1768
1768
std::string odsType = llvm::StringSwitch<std::string>(elementType)
1769
1769
.Case (" f32" , " F32" )
1770
1770
.Case (" i32" , " I32" )
1771
+ .Case (" i64" , " I64" )
1771
1772
.Default (" " );
1772
1773
if (odsType.empty ()) {
1773
1774
parser.emitError (" unimplemented support for attribute element type: " +
@@ -1811,7 +1812,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1811
1812
let regions = (region AnyRegion:$region);
1812
1813
1813
1814
let skipDefaultBuilders = 1;
1814
- let builders = [ OpBuilderDAG<
1815
+ let builders = [
1816
+ OpBuilderDAG<
1815
1817
(ins "ValueRange":$inputs, "ValueRange":$outputs),
1816
1818
[{{
1817
1819
$_state.addOperands(inputs);
@@ -1826,7 +1828,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1826
1828
$_state,
1827
1829
TypeRange(inputs),
1828
1830
TypeRange(outputs));
1829
- }]>, OpBuilderDAG<
1831
+ }]>,
1832
+ OpBuilderDAG<
1830
1833
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1831
1834
"ValueRange":$outputs),
1832
1835
[{{
@@ -1843,7 +1846,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1843
1846
$_state,
1844
1847
TypeRange(inputs),
1845
1848
TypeRange(outputs));
1846
- }]>, OpBuilderDAG<
1849
+ }]>,
1850
+ OpBuilderDAG<
1847
1851
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
1848
1852
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
1849
1853
[{{
@@ -1852,6 +1856,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1852
1856
$_state.addTypes(resultTensorTypes);
1853
1857
(void)$_state.addRegion();
1854
1858
}]>
1859
+ {5}
1855
1860
];
1856
1861
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
1857
1862
let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
@@ -1873,8 +1878,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1873
1878
}];
1874
1879
})FMT" ;
1875
1880
1881
+ // Generate documentation.
1876
1882
std::string doc;
1877
-
1878
1883
if (!docString.empty ()) {
1879
1884
const char *docFmt = R"FMT(
1880
1885
let summary = [{ {0} }];
@@ -1888,8 +1893,47 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1888
1893
doc = llvm::formatv (docFmt, summary.trim (), description.trim ());
1889
1894
}
1890
1895
1896
+ // Generate an additional builder that has parameters for attributes.
1897
+ std::string attrBuilder;
1898
+ if (!registeredAttrs.empty ()) {
1899
+ SmallVector<std::string, 4 > attrParams, attrStmts;
1900
+ for (const auto &attr : registeredAttrs) {
1901
+ llvm::StringRef name = attr.first ;
1902
+ attrParams.push_back (llvm::formatv (" \" Attribute\" :${0}" , name));
1903
+ attrStmts.push_back (
1904
+ llvm::formatv (" $_state.addAttribute(\" {0}\" , {0});" , name));
1905
+ }
1906
+ std::string attrParamsList = llvm::join (attrParams, " , " );
1907
+ std::string attrStmtsList = llvm::join (attrStmts, " \n " );
1908
+
1909
+ const char *builderFmt = R"FMT(
1910
+ , OpBuilderDAG<
1911
+ (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
1912
+ "ValueRange":$outputs, {1}),
1913
+ [{{
1914
+ $_state.addOperands(inputs);
1915
+ $_state.addOperands(outputs);
1916
+ $_state.addTypes(resultTensorTypes);
1917
+ $_state.addAttribute(
1918
+ "operand_segment_sizes",
1919
+ $_builder.getI32VectorAttr({{
1920
+ static_cast<int32_t>(inputs.size()),
1921
+ static_cast<int32_t>(outputs.size())}));
1922
+ buildNamedStructuredOpRegionAndAttributes<{0}>(
1923
+ $_builder,
1924
+ $_state,
1925
+ TypeRange(inputs),
1926
+ TypeRange(outputs));
1927
+ {2}
1928
+ }]>
1929
+ )FMT" ;
1930
+ attrBuilder =
1931
+ llvm::formatv (builderFmt, cppOpName, attrParamsList, attrStmtsList);
1932
+ }
1933
+
1934
+ // Finally put everything together.
1891
1935
os << llvm::formatv (header, cppOpName, linalgOpName, doc, attrList,
1892
- state.orderedTensorArgs .size ());
1936
+ state.orderedTensorArgs .size (), attrBuilder );
1893
1937
}
1894
1938
1895
1939
// / Print the C++ StructuredOpsInterface impl of `iterator_types`.
@@ -2146,13 +2190,15 @@ TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
2146
2190
return llvm::formatv (" getValue<float>({ {0} })" , indexList);
2147
2191
if (elementType == " i32" )
2148
2192
return llvm::formatv (" getValue<int>({ {0} })" , indexList);
2193
+ if (elementType == " i64" )
2194
+ return llvm::formatv (" getValue<int64_t>({ {0} })" , indexList);
2149
2195
2150
2196
return " " ;
2151
2197
}
2152
2198
2153
2199
if (elementType == " f32" )
2154
2200
return " getValue().convertToFloat()" ;
2155
- if (elementType == " i32" )
2201
+ if (elementType == " i32" || elementType == " i64 " )
2156
2202
return " getInt()" ;
2157
2203
return " " ;
2158
2204
}
0 commit comments