Commit 797b0dd

Merge pull request #586 from Xilinx/jrickert.fold_tosa_concat
Add folding for tosa.concat if all inputs are constant foldable
2 parents 20ead4c + 38fc7c0

File tree

2 files changed: +239 -0 lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 115 additions & 0 deletions
@@ -2097,6 +2097,120 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
   const bool aggressiveReduceConstant;
 };
 
+template <typename ElementStorageType>
+DenseElementsAttr
+concatenateAttrs(const ShapedType outputType, ArrayRef<ElementsAttr> inputAttrs,
+                 const uint32_t concatAxis, PatternRewriter &rewriter,
+                 const Type elementType) {
+
+  static_assert(std::is_same<ElementStorageType, APInt>::value ||
+                    std::is_same<ElementStorageType, APFloat>::value,
+                "ElementStorageType must be either APInt or APFloat");
+
+  SmallVector<ElementStorageType> resultValues;
+  if constexpr (std::is_same<ElementStorageType, APInt>::value) {
+    resultValues.resize_for_overwrite(outputType.getNumElements());
+  } else {
+    resultValues.resize(
+        outputType.getNumElements(),
+        APFloat::getZero(cast<FloatType>(elementType).getFloatSemantics()));
+  }
+  const auto outputShape = outputType.getShape();
+
+  int64_t concatDimOffset = 0;
+  for (const auto &inputAttr : inputAttrs) {
+    const auto inputShape = cast<ShapedType>(inputAttr.getType()).getShape();
+    const auto inputValues = inputAttr.getValues<ElementStorageType>();
+
+    for (const auto &[inputLinearIdx, val] : llvm::enumerate(inputValues)) {
+      // TODO: Could be optimized to work on slices instead of single values
+      SmallVector<int64_t> multiDimIndex =
+          offsetToIndex(inputShape, inputLinearIdx);
+      multiDimIndex[concatAxis] += concatDimOffset;
+
+      const int64_t outputLinearIndex =
+          indexToOffset(outputShape, multiDimIndex);
+      resultValues[outputLinearIndex] = val;
+    }
+    concatDimOffset += inputShape[concatAxis];
+  }
+  return DenseElementsAttr::get(outputType, resultValues);
+}
+
+struct TosaFoldConstantConcat : public TosaFoldConstantBase<tosa::ConcatOp> {
+  using TosaFoldConstantBase::TosaFoldConstantBase;
+
+  LogicalResult matchAndRewrite(tosa::ConcatOp op,
+                                PatternRewriter &rewriter) const override {
+    auto inputs = op->getOperands();
+    const uint32_t concatAxis = op.getAxis();
+    const auto outputType = cast<ShapedType>(op.getType());
+    if (!outputType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "Output type must have static shape for concat folding.");
+    }
+    if (llvm::any_of(inputs, [](Value v) {
+          return !cast<ShapedType>(v.getType()).hasStaticShape();
+        })) {
+      return rewriter.notifyMatchFailure(
+          op, "All inputs to ConcatOp must have static shape for folding.");
+    }
+
+    const Type elementType = outputType.getElementType();
+    if (!elementType.isIntOrIndexOrFloat()) {
+      // Sanity check; this should always be the case.
+      return rewriter.notifyMatchFailure(
+          op, "Output element type must be int, index, or float for folding.");
+    }
+
+    SmallVector<ElementsAttr> inputAttrs;
+    inputAttrs.reserve(inputs.size());
+
+    for (Value inputVal : inputs) {
+      ElementsAttr inputAsAttr;
+      if (!matchPattern(inputVal, m_Constant(&inputAsAttr))) {
+        // TODO: This could be extended to handle partially non-const inputs
+        return rewriter.notifyMatchFailure(
+            op, "All inputs to ConcatOp must be constant for folding.");
+      }
+
+      if (inputAsAttr.isSplat()) {
+        const ShapedType inputType = cast<ShapedType>(inputAsAttr.getType());
+        if (isa<IntegerType>(elementType)) {
+          inputAsAttr = DenseElementsAttr::get(
+              inputType, inputAsAttr.getSplatValue<APInt>());
+        } else {
+          inputAsAttr = DenseElementsAttr::get(
+              inputType, inputAsAttr.getSplatValue<APFloat>());
+        }
+      }
+      if (foldSplatOrSingleUseOnly && !inputVal.hasOneUse() &&
+          !inputAsAttr.isSplat()) {
+        return rewriter.notifyMatchFailure(
+            op, "Concat folding heuristic: non-splat constant inputs must have "
+                "only a single use.");
+      }
+      inputAttrs.push_back(inputAsAttr);
+    }
+
+    DenseElementsAttr resultAttr;
+    if (auto intType = dyn_cast<IntegerType>(elementType)) {
+      // TODO: This could be optimized to not go through APInt if the int size
+      // matches a native C++ type
+      resultAttr = concatenateAttrs<APInt>(outputType, inputAttrs, concatAxis,
+                                           rewriter, elementType);
+    } else {
+      resultAttr = concatenateAttrs<APFloat>(outputType, inputAttrs,
+                                             concatAxis, rewriter, elementType);
+    }
+
+    assert(resultAttr && "Result attribute should not be null.");
+
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaFoldConstantPatterns(
@@ -2136,6 +2250,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
   patterns.add<TosaFoldConstantPad>(ctx, options.foldSplatOrSingleUseOnly);
   patterns.add<TosaFoldConstantSlice>(ctx, options.foldSplatOrSingleUseOnly);
   patterns.add<TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly);
+  patterns.add<TosaFoldConstantConcat>(ctx, options.foldSplatOrSingleUseOnly);
   if (options.enableTileFolding)
     patterns.add<TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly);
 }
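
Aside on the indexing in concatenateAttrs above: each element's linear offset within its input is decomposed into a multi-dimensional index, that index is shifted along the concat axis by the running concatDimOffset, and the result is re-linearized against the output shape. Below is a minimal standalone sketch of that arithmetic, assuming the usual row-major element order of DenseElementsAttr; the offsetToIndex/indexToOffset bodies here are illustrative stand-ins, since the real helpers in TosaFolders.cpp are outside this diff.

// Standalone sketch of the index arithmetic used by concatenateAttrs,
// assuming row-major (innermost-dimension-fastest) layout.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Decompose a linear offset into a multi-dimensional index (row-major).
static std::vector<int64_t> offsetToIndex(const std::vector<int64_t> &shape,
                                          int64_t offset) {
  std::vector<int64_t> index(shape.size());
  for (std::size_t i = shape.size(); i-- > 0;) {
    index[i] = offset % shape[i];
    offset /= shape[i];
  }
  return index;
}

// Re-linearize a multi-dimensional index against a (possibly larger) shape.
static int64_t indexToOffset(const std::vector<int64_t> &shape,
                             const std::vector<int64_t> &index) {
  int64_t offset = 0;
  for (std::size_t i = 0; i < shape.size(); ++i)
    offset = offset * shape[i] + index[i];
  return offset;
}

int main() {
  // Concatenating two 2x2 inputs along axis 0 yields a 4x2 output; the
  // second input's rows are shifted by concatDimOffset = 2.
  const std::vector<int64_t> inputShape = {2, 2};
  const std::vector<int64_t> outputShape = {4, 2};
  const int64_t concatAxis = 0, concatDimOffset = 2;

  // Element 2 of the second input has index [1, 0]; shifted it becomes
  // [3, 0], i.e. linear offset 6 in the 4x2 output.
  std::vector<int64_t> idx = offsetToIndex(inputShape, /*offset=*/2);
  idx[concatAxis] += concatDimOffset;
  assert(indexToOffset(outputShape, idx) == 6);
  std::printf("output offset: %lld\n",
              static_cast<long long>(indexToOffset(outputShape, idx)));
}

As the TODO in the diff notes, the same mapping admits a bulk variant: in row-major order, contiguous runs of input elements land contiguously in the output, so whole slices could be copied at once instead of single elements.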
Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+// RUN: mlir-opt --tosa-layerwise-constant-fold %s | FileCheck %s
+
+// CHECK-LABEL: func.func @concat_i32_axis0
+// CHECK-SAME: () -> tensor<4x2xi32> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[1, 2], [3, 4], [5, 6], [7, 8]{{.}}> : tensor<4x2xi32>}> : () -> tensor<4x2xi32>
+// CHECK: return [[VAR_0_]] : tensor<4x2xi32>
+func.func @concat_i32_axis0() -> (tensor<4x2xi32>) {
+  %c0 = "tosa.const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %c1 = "tosa.const"() {value = dense<[[5, 6], [7, 8]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %0 = "tosa.concat"(%c0, %c1) {axis = 0 : i32} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<4x2xi32>
+  return %0 : tensor<4x2xi32>
+}
+
+// CHECK-LABEL: func.func @concat_f32_axis1
+// CHECK-SAME: () -> tensor<2x3xf32> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]{{.}}> : tensor<2x3xf32>}> : () -> tensor<2x3xf32>
+// CHECK: return [[VAR_0_]] : tensor<2x3xf32>
+func.func @concat_f32_axis1() -> (tensor<2x3xf32>) {
+  %c0 = "tosa.const"() {value = dense<[[1.0, 2.0], [4.0, 5.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
+  %c1 = "tosa.const"() {value = dense<[[3.0], [6.0]]> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
+  %0 = "tosa.concat"(%c0, %c1) {axis = 1 : i32} : (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
+  return %0 : tensor<2x3xf32>
+}
+
+// CHECK-LABEL: func.func @concat_i8_three_inputs_axis1
+// CHECK-SAME: () -> tensor<1x5xi8> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[1, 2, 3, 4, 5]{{.}}> : tensor<1x5xi8>}> : () -> tensor<1x5xi8>
+// CHECK: return [[VAR_0_]] : tensor<1x5xi8>
+func.func @concat_i8_three_inputs_axis1() -> (tensor<1x5xi8>) {
+  %c0 = "tosa.const"() {value = dense<[[1, 2]]> : tensor<1x2xi8>} : () -> tensor<1x2xi8>
+  %c1 = "tosa.const"() {value = dense<[[3]]> : tensor<1x1xi8>} : () -> tensor<1x1xi8>
+  %c2 = "tosa.const"() {value = dense<[[4, 5]]> : tensor<1x2xi8>} : () -> tensor<1x2xi8>
+  %0 = "tosa.concat"(%c0, %c1, %c2) {axis = 1 : i32} : (tensor<1x2xi8>, tensor<1x1xi8>, tensor<1x2xi8>) -> tensor<1x5xi8>
+  return %0 : tensor<1x5xi8>
+}
+
+// CHECK-LABEL: func.func @concat_i32_with_splat_axis0
+// CHECK-SAME: () -> tensor<3x1xi32> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[7], [7], [8]{{.}}> : tensor<3x1xi32>}> : () -> tensor<3x1xi32>
+// CHECK: return [[VAR_0_]] : tensor<3x1xi32>
+func.func @concat_i32_with_splat_axis0() -> (tensor<3x1xi32>) {
+  %c0 = "tosa.const"() {value = dense<7> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
+  %c1 = "tosa.const"() {value = dense<[[8]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
+  %0 = "tosa.concat"(%c0, %c1) {axis = 0 : i32} : (tensor<2x1xi32>, tensor<1x1xi32>) -> tensor<3x1xi32>
+  return %0 : tensor<3x1xi32>
+}
+
+// CHECK-LABEL: func.func @concat_bool_axis1
+// CHECK-SAME: () -> tensor<2x2xi1> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[true, false], [false, true]{{.}}> : tensor<2x2xi1>}> : () -> tensor<2x2xi1>
+// CHECK: return [[VAR_0_]] : tensor<2x2xi1>
+func.func @concat_bool_axis1() -> (tensor<2x2xi1>) {
+  %c0 = "tosa.const"() {value = dense<[[true], [false]]> : tensor<2x1xi1>} : () -> tensor<2x1xi1>
+  %c1 = "tosa.const"() {value = dense<[[false], [true]]> : tensor<2x1xi1>} : () -> tensor<2x1xi1>
+  %0 = "tosa.concat"(%c0, %c1) {axis = 1 : i32} : (tensor<2x1xi1>, tensor<2x1xi1>) -> tensor<2x2xi1>
+  return %0 : tensor<2x2xi1>
+}
+
+// CHECK-LABEL: func.func @concat_rank1_i32_axis0
+// CHECK-SAME: () -> tensor<5xi32> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[1, 2, 3, 4, 5]> : tensor<5xi32>}> : () -> tensor<5xi32>
+// CHECK: return [[VAR_0_]] : tensor<5xi32>
+func.func @concat_rank1_i32_axis0() -> (tensor<5xi32>) {
+  %c0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %c1 = "tosa.const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> tensor<2xi32>
+  %0 = "tosa.concat"(%c0, %c1) {axis = 0 : i32} : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32>
+  return %0 : tensor<5xi32>
+}
+
+// CHECK-LABEL: func.func @concat_empty_tensor_axis0
+// CHECK-SAME: () -> tensor<2x2xi32> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[1, 2], [3, 4]{{.}}> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
+// CHECK: return [[VAR_0_]] : tensor<2x2xi32>
+func.func @concat_empty_tensor_axis0() -> (tensor<2x2xi32>) {
+  %c0 = "tosa.const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %c1 = "tosa.const"() {value = dense<> : tensor<0x2xi32>} : () -> tensor<0x2xi32>
+  %0 = "tosa.concat"(%c0, %c1) {axis = 0 : i32} : (tensor<2x2xi32>, tensor<0x2xi32>) -> tensor<2x2xi32>
+  return %0 : tensor<2x2xi32>
+}
+
+// CHECK-LABEL: func.func @concat_all_empty_tensors_axis1
+// CHECK-SAME: () -> tensor<2x0xi32> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<> : tensor<2x0xi32>}> : () -> tensor<2x0xi32>
+// CHECK: return [[VAR_0_]] : tensor<2x0xi32>
+func.func @concat_all_empty_tensors_axis1() -> (tensor<2x0xi32>) {
+  %c0 = "tosa.const"() {value = dense<> : tensor<2x0xi32>} : () -> tensor<2x0xi32>
+  %c1 = "tosa.const"() {value = dense<> : tensor<2x0xi32>} : () -> tensor<2x0xi32>
+  %0 = "tosa.concat"(%c0, %c1) {axis = 1 : i32} : (tensor<2x0xi32>, tensor<2x0xi32>) -> tensor<2x0xi32>
+  return %0 : tensor<2x0xi32>
+}
+
+// CHECK-LABEL: func.func @concat_i32_axis1_three_inputs_two_splats
+// CHECK-SAME: () -> tensor<2x4xi32> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[1, 10, 11, 2], [1, 12, 13, 2]{{.}}> : tensor<2x4xi32>}> : () -> tensor<2x4xi32>
+// CHECK: return [[VAR_0_]] : tensor<2x4xi32>
+func.func @concat_i32_axis1_three_inputs_two_splats() -> (tensor<2x4xi32>) {
+  %c0_splat = "tosa.const"() {value = dense<1> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
+  %c1_dense = "tosa.const"() {value = dense<[[10, 11], [12, 13]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
+  %c2_splat = "tosa.const"() {value = dense<2> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
+  %0 = "tosa.concat"(%c0_splat, %c1_dense, %c2_splat) {axis = 1 : i32} : (tensor<2x1xi32>, tensor<2x2xi32>, tensor<2x1xi32>) -> tensor<2x4xi32>
+  return %0 : tensor<2x4xi32>
+}
+
+// CHECK-LABEL: func.func @concat_ui16_axis0
+// CHECK-SAME: () -> tensor<3x2xui16> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[100, 200], [300, 400], [500, 600]{{.}}> : tensor<3x2xui16>}> : () -> tensor<3x2xui16>
+// CHECK: return [[VAR_0_]] : tensor<3x2xui16>
+func.func @concat_ui16_axis0() -> (tensor<3x2xui16>) {
+  %c0 = "tosa.const"() {value = dense<[[100, 200], [300, 400]]> : tensor<2x2xui16>} : () -> tensor<2x2xui16>
+  %c1 = "tosa.const"() {value = dense<[[500, 600]]> : tensor<1x2xui16>} : () -> tensor<1x2xui16>
+  %0 = "tosa.concat"(%c0, %c1) {axis = 0 : i32} : (tensor<2x2xui16>, tensor<1x2xui16>) -> tensor<3x2xui16>
+  return %0 : tensor<3x2xui16>
+}
+
+// CHECK-LABEL: func.func @concat_3d_bf16_axis1
+// CHECK-SAME: () -> tensor<2x3x2xbf16> {
+// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}{{.}}[1.000000e+00, 2.000000e+00], [5.000000e+00, 6.000000e+00], [7.000000e+00, 8.000000e+00]{{.}}, {{.}}[3.000000e+00, 4.000000e+00], [9.000000e+00, 1.000000e+01], [1.100000e+01, 1.200000e+01]{{.}}{{.}}> : tensor<2x3x2xbf16>}> : () -> tensor<2x3x2xbf16>
+// CHECK: return [[VAR_0_]] : tensor<2x3x2xbf16>
+func.func @concat_3d_bf16_axis1() -> (tensor<2x3x2xbf16>) {
+  %c0 = "tosa.const"() {value = dense<[[[1.0, 2.0]], [[3.0, 4.0]]]> : tensor<2x1x2xbf16>} : () -> tensor<2x1x2xbf16>
+  %c1 = "tosa.const"() {value = dense<[[[5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0]]]> : tensor<2x2x2xbf16>} : () -> tensor<2x2x2xbf16>
+  %0 = "tosa.concat"(%c0, %c1) {axis = 1 : i32} : (tensor<2x1x2xbf16>, tensor<2x2x2xbf16>) -> tensor<2x3x2xbf16>
+  return %0 : tensor<2x3x2xbf16>
+}
