@@ -60,9 +60,51 @@ struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
   }
 };
 
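+/// Rewrite a concat whose operands are all the same value into a tile along
+/// the concat axis. Illustrative IR sketch (shapes and attribute syntax are
+/// an assumed example):
+///
+///   %c = tosa.concat %x, %x {axis = 0 : i32}
+///       : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<4x3xf32>
+/// becomes
+///   %t = tosa.tile %x {multiples = array<i64: 2, 1>}
+///       : (tensor<2x3xf32>) -> tensor<4x3xf32>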
+struct SelfConcatToTile : public OpRewritePattern<tosa::ConcatOp> {
+  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    if (llvm::all_equal(concatOp->getUsers())) {
+      const auto concatUser = llvm::dyn_cast<tosa::ConcatOp>(
+          concatOp->getUses().begin()->getOwner());
+      if (concatUser) {
+        // Try folding the concat into its consumer before rewriting it to a
+        // tile.
+        SmallVector<Value> replacementValues;
+        auto foldResult = rewriter.tryFold(concatUser, replacementValues);
+        if (foldResult.succeeded()) {
+          if (!replacementValues.empty()) {
+            rewriter.replaceOp(concatUser, replacementValues);
+          }
+          return success();
+        }
+      }
+    }
+
+    if (!llvm::all_equal(concatOp->getOperands())) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "Requires all operands to be the same");
+    }
+    const auto concatType = dyn_cast<ShapedType>(concatOp.getType());
+    if (!concatType || !concatType.hasRank()) {
+      return rewriter.notifyMatchFailure(concatOp,
+                                         "Requires concat to be ranked");
+    }
+    SmallVector<int64_t> multiplies(concatType.getRank(), 1);
+    multiplies[concatOp.getAxis()] = concatOp->getNumOperands();
+    auto tileOp = rewriter.createOrFold<tosa::TileOp>(
+        concatOp->getLoc(), concatOp.getType(), concatOp->getOperand(0),
+        multiplies);
+    rewriter.replaceOp(concatOp, {tileOp});
+    return success();
+  }
+};
+
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
   results.add<ConcatOptimization>(context);
+  results.add<SelfConcatToTile>(context);
 }
 
 struct SqrtReciprocalOptimization : public OpRewritePattern<tosa::PowOp> {
@@ -611,42 +653,120 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
 
     llvm::SmallVector<int64_t> sliceStart(sliceOp.getStart());
     llvm::ArrayRef<int64_t> sliceSize = sliceOp.getSize();
-
-    // Validate slice on the concatenated axis. Slicing along this
-    // axis should span only one of the inputs to the concatenate
-    // operation.
-    std::optional<Value> replaceWithSlice;
+    llvm::SmallVector<Value> requiredConcatInputs;
+    int64_t processedOriginalConcatInputSize = 0;
+    int64_t droppedConcatInputSize = 0;
     for (auto input : inputs) {
-      auto inputType = dyn_cast<RankedTensorType>(input.getType());
+      const auto inputType = dyn_cast<RankedTensorType>(input.getType());
       if (!inputType || !inputType.hasStaticShape())
         return rewriter.notifyMatchFailure(
             sliceOp, "concat input must be a static ranked tensor");
-
-      if (sliceStart[axis] >= 0 &&
-          (sliceStart[axis] + sliceSize[axis]) <= inputType.getDimSize(axis)) {
-        replaceWithSlice = rewriter
-                               .create<tosa::SliceOp>(
-                                   sliceOp.getLoc(), sliceOp.getType(), input,
-                                   rewriter.getDenseI64ArrayAttr(sliceStart),
-                                   rewriter.getDenseI64ArrayAttr(sliceSize))
-                               .getResult();
-        break;
+      if (processedOriginalConcatInputSize <
+              (sliceStart[axis] + sliceSize[axis]) &&
+          (processedOriginalConcatInputSize + inputType.getDimSize(axis)) >
+              sliceStart[axis]) {
+        if (requiredConcatInputs.empty()) {
+          droppedConcatInputSize = processedOriginalConcatInputSize;
+        }
+        requiredConcatInputs.push_back(input);
       }
-      sliceStart[axis] -= inputType.getDimSize(axis);
+      processedOriginalConcatInputSize += inputType.getDimSize(axis);
+    }
+    if (requiredConcatInputs.size() == concatOp->getNumOperands()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "Could not reduce number of inputs to preceding concat");
+    }
+    if (requiredConcatInputs.size() != 1 && !concatOp->hasOneUse()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "Preceding concat must have a single use"); // Do not introduce new
+                                                      // concats
+    }
+    if (requiredConcatInputs.empty()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "degenerate slice with zero sized dim in output");
     }
+    sliceStart[axis] -= droppedConcatInputSize;
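+    // E.g. for concat inputs of sizes 2, 3 and 4 along the axis, a slice
+    // with start 3 and size 3 only overlaps the last two inputs (covering
+    // [2, 5) and [5, 9)), so the first input is dropped and the slice start
+    // shifts to 3 - 2 = 1. (The sizes here are an illustrative example.)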
+    auto newConcat = rewriter.create<tosa::ConcatOp>(
+        concatOp->getLoc(), requiredConcatInputs, axis);
+    auto newSlice = rewriter.create<tosa::SliceOp>(
+        sliceOp->getLoc(), sliceOp.getType(), newConcat,
+        rewriter.getDenseI64ArrayAttr(sliceStart),
+        rewriter.getDenseI64ArrayAttr(sliceSize));
+    rewriter.replaceOp(sliceOp, newSlice);
+    return success();
+  }
+};
+
+/// This pattern adjusts the multiples of a tile followed by a slice so that
+/// only as much data is tiled as the slice requires.
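+/// Illustrative example (shapes assumed): slicing 4 rows starting at row 0
+/// out of a tile with multiples [4, 1] of a tensor<2x3xf32> needs only
+/// multiples [2, 1], since two copies of the input already cover rows [0, 4).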
+struct TileSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+    auto tileOp = sliceInput.getDefiningOp<tosa::TileOp>();
+    if (!tileOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be tile operation");
+    if (!tileOp->hasOneUse())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "preceding tile must have a single use"); // Do not insert
+                                                             // additional tiles
 
-    if (!replaceWithSlice)
+    const auto tileOpInputType =
+        dyn_cast<RankedTensorType>(tileOp->getOperand(0).getType());
+    if (!tileOpInputType || !tileOpInputType.hasStaticShape())
       return rewriter.notifyMatchFailure(
-          sliceOp, "corresponding concat input not found for slice");
+          sliceOp, "input to preceding tile op must be a static ranked tensor");
+    llvm::SmallVector<int64_t> requiredMultipliers;
+    llvm::SmallVector<int64_t> newTileStarts;
+    requiredMultipliers.reserve(tileOpInputType.getRank());
+    newTileStarts.reserve(tileOpInputType.getRank());
+    for (auto [axis, sliceStart, sliceSize] :
+         llvm::enumerate(sliceOp.getStart(), sliceOp.getSize())) {
+      if (sliceSize <= 0) {
+        return rewriter.notifyMatchFailure(
+            sliceOp, "degenerate slice with zero sized dim");
+      }
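+      // Worked example (illustrative numbers): with tileInputDimSize = 4,
+      // sliceStart = 5 and sliceSize = 6, the slice begins at offset
+      // 5 % 4 = 1 inside one copy, takes min(4 - 1, 6) = 3 elements from
+      // that copy plus ceil((6 - 3) / 4) = 1 further copy, so a multiplier
+      // of 2 suffices and the new slice starts at offset 1.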
+      const int64_t tileInputDimSize = tileOpInputType.getDimSize(axis);
+      const int64_t sliceOffsetInNewFirstTile = sliceStart % tileInputDimSize;
+      const int64_t sliceSizeInFirstTile =
+          std::min(tileInputDimSize - sliceOffsetInNewFirstTile, sliceSize);
+      assert(sliceSizeInFirstTile > 0);
+      const int64_t requiredMultiplierWithoutFirstTile =
+          llvm::divideCeil(sliceSize - sliceSizeInFirstTile, tileInputDimSize);
+      const int64_t requiredMultiplier =
+          requiredMultiplierWithoutFirstTile + (sliceSizeInFirstTile != 0);
+      assert(requiredMultiplier <= tileOp.getMultiples()[axis]);
+      requiredMultipliers.push_back(requiredMultiplier);
+      newTileStarts.push_back(sliceOffsetInNewFirstTile);
+    }
+    if (requiredMultipliers == tileOp.getMultiples())
+      return rewriter.notifyMatchFailure(
+          sliceOp, "could not reduce multipliers in preceding tile");
 
-    rewriter.replaceOp(sliceOp, replaceWithSlice.value());
+    llvm::SmallVector<int64_t> newTileShape(tileOpInputType.getShape());
+    for (auto [newShape, multiplier] :
+         llvm::zip_equal(newTileShape, requiredMultipliers)) {
+      newShape *= multiplier;
+    }
+    auto newTile = rewriter.create<tosa::TileOp>(
+        tileOp->getLoc(), tileOpInputType.clone(newTileShape),
+        tileOp->getOperand(0), requiredMultipliers);
+    auto newSlice = rewriter.create<tosa::SliceOp>(
+        sliceOp->getLoc(), sliceOp.getType(), newTile,
+        rewriter.getDenseI64ArrayAttr(newTileStarts), sliceOp.getSizeAttr());
+    rewriter.replaceOp(sliceOp, newSlice);
     return success();
   }
 };
 
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<ConcatSliceOptimization>(context);
+  results.add<TileSliceOptimization>(context);
 }
 
 struct MinToClampOptimization : public OpRewritePattern<tosa::MinimumOp> {
@@ -1321,6 +1441,21 @@ OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
   bool allOnes = llvm::all_of(getMultiples(), [](int64_t v) { return v == 1; });
   if (allOnes && getInput1().getType() == getType())
     return getInput1();
+
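+  // Fold a tile of a tile by composing their multiples element-wise: e.g.
+  // tiling by [1, 3] after tiling by [2, 1] equals a single tile by [2, 3].
+  // (The multiples shown are an illustrative example.)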
+  if (auto inputTile = getInput1().getDefiningOp<TileOp>()) {
+    if (!inputTile->hasOneUse()) {
+      return {};
+    }
+    llvm::SmallVector<int64_t> newMultiplies{getMultiples()};
+    for (auto [idx, multiplier] : llvm::enumerate(inputTile.getMultiples())) {
+      newMultiplies[idx] *= multiplier;
+    }
+    setMultiples(newMultiplies);
+    setOperand(inputTile->getOperand(0));
+    getOperation()->setLoc(
+        FusedLoc::get(getContext(), {inputTile->getLoc(), getLoc()}));
+    return getResult();
+  }
   return {};
 }
 