@@ -724,6 +724,25 @@ static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
724724 applyPermutationToVector<OpFoldResult>(sizes, permutation);
725725}
726726
727+ static FailureOr<int64_t > getStaticPartOfScalableTileSize (Operation *op) {
728+ auto mulIOp = dyn_cast<arith::MulIOp>(op);
729+ if (!mulIOp)
730+ return failure ();
731+
732+ auto lhs = mulIOp.getLhs ().getDefiningOp ();
733+ auto rhs = mulIOp.getRhs ().getDefiningOp ();
734+
735+ auto cstOp = isa<arith::ConstantOp>(lhs) ? cast<arith::ConstantOp>(lhs)
736+ : dyn_cast<arith::ConstantOp>(rhs);
737+ if (!cstOp)
738+ return failure ();
739+ if (!isa<vector::VectorScaleOp>(lhs) && !isa<vector::VectorScaleOp>(rhs))
740+ return failure ();
741+ if (auto integerAttr = dyn_cast<IntegerAttr>(cstOp.getValue ()))
742+ return integerAttr.getInt ();
743+ return failure ();
744+ }
745+
727746struct PackOpTiling
728747 : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {
729748
@@ -950,7 +969,30 @@ struct PackOpTiling
950969 // other words, we can only support tiling with consumer if the tile
951970 // size for the producer is a multiple of the inner tile size for the
952971 // packed dimensions at this moment.
953- if ((failed (cstTileSize) || !cstInnerSize ||
972+ bool assumeInnerTileSizesMatchTiles = false ;
973+ bool isAlignedToInnerTileSize = false ;
974+ if (!cstInnerSize) {
975+ // Warning: Hard-coded separation of distribution and vector level tiling :D
976+ Value scalableInnerTileSize = cast<Value>(dimAndTileMapping[dim]);
977+ auto staticInnerTileSize =
978+ getStaticPartOfScalableTileSize (scalableInnerTileSize.getDefiningOp ());
979+ isAlignedToInnerTileSize = succeeded (staticInnerTileSize);
980+ if (auto tileSizeVal = dyn_cast<Value>(sizes[dim])) {
981+ if (auto tileSizeAffineOp =
982+ tileSizeVal.getDefiningOp <affine::AffineMinOp>()) {
983+ if (tileSizeAffineOp->getNumOperands () >= 2 ) {
984+ Value scalableTileSize = tileSizeAffineOp->getOperand (tileSizeAffineOp->getNumOperands () - 1 );
985+ auto staticTileSize =
986+ getStaticPartOfScalableTileSize (scalableTileSize.getDefiningOp ());
987+ assumeInnerTileSizesMatchTiles =
988+ succeeded (staticTileSize) && succeeded (staticInnerTileSize) &&
989+ staticTileSize.value () == staticInnerTileSize.value ();
990+ }
991+ }
992+ }
993+ }
994+ // TODO: here comes the alignment flag
995+ if (!isAlignedToInnerTileSize && (failed (cstTileSize) || !cstInnerSize ||
954996 *cstTileSize % *cstInnerSize != 0 ))
955997 return failure ();
956998
@@ -963,6 +1005,7 @@ struct PackOpTiling
9631005 auto avSize = AV (dim0).bind (sizes[dim]);
9641006 auto avTileSize = AV (sym).bind (dimAndTileMapping[dim]);
9651007 outerDimOffsets.push_back (ab.floor (avOffset, avTileSize));
1008+ // TODO: here comes the equality one and 1 outer dim :)
9661009 outerDimSizes.push_back (ab.ceil (avSize, avTileSize));
9671010 } else {
9681011 outerDimOffsets.push_back (offsets[dim]);
@@ -1040,25 +1083,6 @@ struct UnpackTileDimInfo {
10401083 OpFoldResult destExpandedSize;
10411084};
10421085
1043- FailureOr<int64_t > getStaticPartOfScalableTileSize (Operation *op) {
1044- auto mulIOp = dyn_cast<arith::MulIOp>(op);
1045- if (!mulIOp)
1046- return failure ();
1047-
1048- auto lhs = mulIOp.getLhs ().getDefiningOp ();
1049- auto rhs = mulIOp.getRhs ().getDefiningOp ();
1050-
1051- auto cstOp = isa<arith::ConstantOp>(lhs) ? cast<arith::ConstantOp>(lhs)
1052- : dyn_cast<arith::ConstantOp>(rhs);
1053- if (!cstOp)
1054- return failure ();
1055- if (!isa<vector::VectorScaleOp>(lhs) && !isa<vector::VectorScaleOp>(rhs))
1056- return failure ();
1057- if (auto integerAttr = dyn_cast<IntegerAttr>(cstOp.getValue ()))
1058- return integerAttr.getInt ();
1059- return failure ();
1060- }
1061-
10621086// / Returns the needed information for tiling unpack op on `tileDim` with given
10631087// / `tileOffset` and `tileSize`. For more details, see the comment of the
10641088// / `getTiledImplementation`.
0 commit comments