@@ -167,10 +167,11 @@ exactTilingOnPackUnPackFilter(RewriterBase &rewriter,
167167 tileSizesOnInnerDims =
168168 llvm::to_vector (ArrayRef (tileSizes).take_back (innerTiles.size ()));
169169 } else {
170- // Upstream doesn't implement `getTiledImplementationFromOperandTile`
171- // interface of `packOp` so far. In another word, `packOp` could not be
172- // fused as consumer. As a result, just return failure currently.
173- return failure ();
170+ // tileSize comes from OpOperand
171+ ArrayRef<int64_t > innerDimPos = packOp.getInnerDimsPos ();
172+ for (auto &pos : innerDimPos) {
173+ tileSizesOnInnerDims.push_back (tileSizes[pos]);
174+ }
174175 }
175176 } else if (auto unPackOp = dyn_cast<tensor::UnPackOp>(defOrUse.ownerOp )) {
176177 innerTiles = unPackOp.getMixedTiles ();
@@ -478,8 +479,8 @@ tileAndFuseProducerOfOpOperand(RewriterBase &rewriter, OpOperand &operand,
478479 return std::nullopt ;
479480
480481 // c. Check whether the producer of the root source is tilable.
481- Operation *producer = realProducer->getDefiningOp <TilingInterface>();
482- if (!producer )
482+ Operation *producerOp = realProducer->getDefiningOp <TilingInterface>();
483+ if (!producerOp )
483484 return std::nullopt ;
484485
485486 CandidateDefOrUse defOrUse{*realProducer};
@@ -536,8 +537,8 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
536537 SmallVector<scf::SCFFuseConsumerOfSliceResult> fusedResultList;
537538 for (auto useOperand : *realConsumers) {
538539 // c. Check whether the consumer of the top level result is tilable.
539- Operation *consumer = dyn_cast<TilingInterface>(useOperand->getOwner ());
540- if (!consumer )
540+ Operation *consumerOp = dyn_cast<TilingInterface>(useOperand->getOwner ());
541+ if (!consumerOp )
541542 continue ;
542543
543544 CandidateDefOrUse defOrUse{useOperand};
@@ -559,7 +560,7 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
559560 // f. Manually run cse on the region which contains the original consumer op
560561 // to avoid conflicts when the subsequent `tileAndFuseConsumerOfSlice` gets the
561562 // nest loops between the next candidate sliceOp and the tiled producer.
562- (void )mlir::simplifyRegions (rewriter, {*consumer ->getParentRegion ()});
563+ (void )mlir::simplifyRegions (rewriter, {*consumerOp ->getParentRegion ()});
563564 }
564565 }
565566 if (fusedResultList.empty ())
@@ -647,11 +648,18 @@ static LogicalResult isTiledOpInLoop(Operation *targetOp) {
647648
648649using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t >>;
649650
651+ struct defaultTileConfig {
652+ // Maps an op name (dialect prefix erased) to its user-specified default tile sizes.
653+ OpTileSizeMap tsMap;
654+ // Upper bound on the number of tiled dims; once reached, further parallel dims get tile size 1.
655+ unsigned ndTile;
656+ };
657+
650658/// Default tiling function, only effective for operations of a certain `OpTy`.
651659static FailureOr<scf::SCFTilingResult>
652660defaultTilingOfType (RewriterBase &rewriter, Operation *op,
653661 function_ref<bool (Operation *)> isaOpTy,
654- const OpTileSizeMap &tsMap ) {
662+ const defaultTileConfig &cfg ) {
655663 // a. Check <OpTy>
656664 if (!isa<TilingInterface>(op) || !isaOpTy (op))
657665 return failure ();
@@ -672,18 +680,20 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
672680 // Erase dialect name, such as Linalg or Tensor.
673681 opName.erase (0 , opName.find (" ." ) + 1 );
674682
675- if (tsMap.count (opName)) {
676- SmallVector<int64_t > userDefaultTileSize = tsMap.find (opName)->second ;
683+ if (cfg. tsMap .count (opName)) {
684+ SmallVector<int64_t > userDefaultTileSize = cfg. tsMap .find (opName)->second ;
677685 defaultTileSize =
678686 getAsOpFoldResult (rewriter.getI64ArrayAttr (userDefaultTileSize));
679687 } else {
680688 defaultTileSize.resize (iteratorTypes.size (), rewriter.getIndexAttr (0 ));
681689 // Try tileSize from `32` to `16`.
682690 SmallVector<int64_t > tsOrder = {32 , 16 };
683- // Only 2D tile is expected.
684- int tileDims = (isa<mlir::linalg::LinalgOp>(op) && !linalgx::isMatmulOp (op))
685- ? cast<mlir::linalg::LinalgOp>(op).getNumReductionLoops ()
686- : 0 ;
691+ // Record how many dims have been tiled, including fully tiled, i.e.
692+ // tileSize == dimSize.
693+ unsigned nonOneTileDims =
694+ (isa<mlir::linalg::LinalgOp>(op) && !linalgx::isMatmulOp (op))
695+ ? cast<mlir::linalg::LinalgOp>(op).getNumReductionLoops ()
696+ : 0 ;
687697 // Reverse both of iteration type and domain from inner to outer.
688698 std::reverse (iteratorTypes.begin (), iteratorTypes.end ());
689699 std::reverse (iterationDomain.begin (), iterationDomain.end ());
@@ -692,21 +702,29 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
692702 // All parallel iterators will be tiled by `32` or `16`. To specify another
693703 // size, please set option `defaultTileSize`, like `matmul:{64,64}`.
694704 if (iterType == utils::IteratorType::parallel) {
695- Range curDomain = iterationDomain[en];
696- std::optional<int64_t > tripCount = mlir::constantTripCount (
697- curDomain.offset , curDomain.size , curDomain.stride );
698- if (tileDims >= 2 && en > 0 ) {
705+ if (nonOneTileDims >= cfg.ndTile && en > 0 ) {
699706 defaultTileSize[en] = rewriter.getIndexAttr (1 );
700707 continue ;
701- } else if (tripCount) {
708+ }
709+ Range curDomain = iterationDomain[en];
710+ if (std::optional<int64_t > tripCount = mlir::constantTripCount (
711+ curDomain.offset , curDomain.size , curDomain.stride )) {
712+ // skip dummy tiling.
713+ if (tripCount == 1 )
714+ continue ;
702715 for (auto &ts : tsOrder) {
703- if (*tripCount % ts == 0 && *tripCount > ts) {
716+ // If `tripCount` equals `tileSize`, do NOT explicitly tile it, in order
717+ // to avoid a non-zero offset.
718+ if (*tripCount == ts)
719+ break ;
720+ if (*tripCount % ts == 0 ) {
704721 defaultTileSize[en] = rewriter.getIndexAttr (ts);
705722 break ;
706723 }
707724 }
708725 }
709- tileDims++;
726+ // Fallback to fully tiled.
727+ nonOneTileDims++;
710728 }
711729 }
712730 }
@@ -731,7 +749,7 @@ defaultTilingOfType(RewriterBase &rewriter, Operation *op,
731749
732750void iterativeTilingAndFusionUntilExhaustion (
733751 RewriterBase &rewriter, func::FuncOp &f,
734- const CandidateSliceOptions &sliceOptions, const OpTileSizeMap &tsMap ) {
752+ const CandidateSliceOptions &sliceOptions, const defaultTileConfig &cfg ) {
735753 // Collect untiled and tiled ops respectively
736754 llvm::SetVector<Operation *> tiledOps, unTiledOps;
737755
@@ -799,7 +817,7 @@ void iterativeTilingAndFusionUntilExhaustion(
799817 for (auto &isaOpTy : priorityOpTypeOrder) {
800818 for (auto &op : unTiledOps) {
801819 FailureOr<scf::SCFTilingResult> tilingResult =
802- defaultTilingOfType (rewriter, op, isaOpTy, tsMap );
820+ defaultTilingOfType (rewriter, op, isaOpTy, cfg );
803821 if (succeeded (tilingResult)) {
804822 tiledOps.insert (tilingResult->tiledOps [0 ]);
805823 rewriter.replaceOp (op, tilingResult->replacements );
@@ -881,8 +899,8 @@ struct IterativeTilingAndFusion
881899 // Get rewriter
882900 IRRewriter rewriter (&ctx);
883901 // Run iterative fusion
884- iterativeTilingAndFusionUntilExhaustion (rewriter, func, sliceOptions,
885- tsMap);
902+ iterativeTilingAndFusionUntilExhaustion (
903+ rewriter, func, sliceOptions, defaultTileConfig{ tsMap, defaultNDTile} );
886904 }
887905};
888906
0 commit comments