@@ -2097,6 +2097,44 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
20972097 const bool aggressiveReduceConstant;
20982098};
20992099
2100+ struct TosaFoldConstantConcat : TosaFoldConstantBase<tosa::ConcatOp> {
2101+ using TosaFoldConstantBase::TosaFoldConstantBase;
2102+
2103+ LogicalResult matchAndRewrite (tosa::ConcatOp op,
2104+ PatternRewriter &rewriter) const override {
2105+ auto inputsRange = op.getInput1 ();
2106+ auto axis = op.getAxis ();
2107+
2108+ // TODO: Matching constraints
2109+
2110+ // collect all inputvalues
2111+ SmallVector<DenseElementsAttr> inputValuesArr;
2112+ SmallVector<ShapedType> inputTypesArr;
2113+ for (auto input : inputsRange) {
2114+ DenseElementsAttr inputValues;
2115+ if (!matchPattern (input, m_Constant (&inputValues))) {
2116+ return failure ();
2117+ }
2118+ inputValuesArr.push_back (inputValues);
2119+ inputTypesArr.push_back (cast<ShapedType>(input.getType ()));
2120+ }
2121+
2122+ // compute result
2123+ auto result = this ->concat (inputValuesArr, inputTypesArr, axis);
2124+
2125+ rewriter.replaceOpWithNewOp <tosa::ConstOp>(op, result.first , result.second );
2126+ }
2127+
2128+ std::pair<ShapedType, DenseElementsAttr>
2129+ concat (SmallVector<DenseElementsAttr> &inputValuesArr,
2130+ SmallVector<ShapedType> &inputTypesArr, uint32_t axis) const {
2131+ auto baseType = inputTypesArr[0 ].getElementType ();
2132+ switch (dyn_cast<IntegerType>(baseType).getWidth ()) {
2133+ // TODO:
2134+ }
2135+ }
2136+ }
2137+
21002138} // namespace
21012139
21022140void mlir::tosa::populateTosaFoldConstantPatterns (
@@ -2136,6 +2174,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
21362174 patterns.add <TosaFoldConstantPad>(ctx, options.foldSplatOrSingleUseOnly );
21372175 patterns.add <TosaFoldConstantSlice>(ctx, options.foldSplatOrSingleUseOnly );
21382176 patterns.add <TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly );
2177+ patterns.add <TosaFoldConstantConcat>(ctx, options.foldSplatOrSingleUseOnly );
21392178 if (options.enableTileFolding )
21402179 patterns.add <TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly );
21412180}
0 commit comments