Skip to content

Commit 43dc066

Browse files
committed
wip: constant folding for concat
1 parent 20ead4c commit 43dc066

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,6 +2097,44 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
20972097
const bool aggressiveReduceConstant;
20982098
};
20992099

2100+
struct TosaFoldConstantConcat : TosaFoldConstantBase<tosa::ConcatOp> {
2101+
using TosaFoldConstantBase::TosaFoldConstantBase;
2102+
2103+
LogicalResult matchAndRewrite(tosa::ConcatOp op,
2104+
PatternRewriter &rewriter) const override {
2105+
auto inputsRange = op.getInput1();
2106+
auto axis = op.getAxis();
2107+
2108+
// TODO: Matching constraints
2109+
2110+
// collect all inputvalues
2111+
SmallVector<DenseElementsAttr> inputValuesArr;
2112+
SmallVector<ShapedType> inputTypesArr;
2113+
for (auto input : inputsRange) {
2114+
DenseElementsAttr inputValues;
2115+
if (!matchPattern(input, m_Constant(&inputValues))) {
2116+
return failure();
2117+
}
2118+
inputValuesArr.push_back(inputValues);
2119+
inputTypesArr.push_back(cast<ShapedType>(input.getType()));
2120+
}
2121+
2122+
// compute result
2123+
auto result = this->concat(inputValuesArr, inputTypesArr, axis);
2124+
2125+
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, result.first, result.second);
2126+
}
2127+
2128+
std::pair<ShapedType, DenseElementsAttr>
2129+
concat(SmallVector<DenseElementsAttr> &inputValuesArr,
2130+
SmallVector<ShapedType> &inputTypesArr, uint32_t axis) const {
2131+
auto baseType = inputTypesArr[0].getElementType();
2132+
switch (dyn_cast<IntegerType>(baseType).getWidth()) {
2133+
// TODO:
2134+
}
2135+
}
2136+
}
2137+
21002138
} // namespace
21012139

21022140
void mlir::tosa::populateTosaFoldConstantPatterns(
@@ -2136,6 +2174,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
21362174
patterns.add<TosaFoldConstantPad>(ctx, options.foldSplatOrSingleUseOnly);
21372175
patterns.add<TosaFoldConstantSlice>(ctx, options.foldSplatOrSingleUseOnly);
21382176
patterns.add<TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly);
2177+
patterns.add<TosaFoldConstantConcat>(ctx, options.foldSplatOrSingleUseOnly);
21392178
if (options.enableTileFolding)
21402179
patterns.add<TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly);
21412180
}

0 commit comments

Comments
 (0)