@@ -56,10 +56,21 @@ namespace {
5656// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
5757//
5858struct TileCheck : public AffineExprVisitor <TileCheck> {
59- TileCheck (ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
59+ TileCheck (ArrayRef<OpFoldResult> tileSizes, ArrayRef<int64_t > domainSizes)
60+ : tileSizes(tileSizes), domainSizes(domainSizes) {}
6061
6162 void visitDimExpr (AffineDimExpr expr) {
62- isTiled |= !isZeroIndex (tileSizes[expr.getPosition ()]);
63+ unsigned pos = expr.getPosition ();
64+
65+ // There is no tile if all tile sizes correspond to the domain size
66+ std::optional<int64_t > tileSize = getConstantIntValue (tileSizes[pos]);
67+ if (tileSize && !domainSizes.empty ()) {
68+ if (domainSizes[pos] == *tileSize) {
69+ return ;
70+ }
71+ }
72+
73+ isTiled |= !isZeroIndex (tileSizes[pos]);
6374 }
6475 void visitAffineBinaryOpExpr (AffineBinaryOpExpr expr) {
6576 visit (expr.getLHS ());
@@ -70,24 +81,28 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
7081 }
7182 bool isTiled = false ;
7283 ArrayRef<OpFoldResult> tileSizes;
84+ ArrayRef<int64_t > domainSizes;
7385};
7486
7587} // namespace
7688
77- static bool isTiled (AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
89+ static bool isTiled (AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
90+ ArrayRef<int64_t > domainSizes) {
7891 if (!expr)
7992 return false ;
80- TileCheck t (tileSizes);
93+
94+ TileCheck t (tileSizes, domainSizes);
8195 t.visit (expr);
8296 return t.isTiled ;
8397}
8498
8599// Checks whether the `map varies with respect to a non-zero `tileSize`.
86- static bool isTiled (AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
100+ static bool isTiled (AffineMap map, ArrayRef<OpFoldResult> tileSizes,
101+ ArrayRef<int64_t > domainSizes) {
87102 if (!map)
88103 return false ;
89104 for (unsigned r = 0 ; r < map.getNumResults (); ++r)
90- if (isTiled (map.getResult (r), tileSizes))
105+ if (isTiled (map.getResult (r), tileSizes, domainSizes ))
91106 return true ;
92107 return false ;
93108}
@@ -556,19 +571,19 @@ Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
556571 ArrayRef<OpFoldResult> lbs,
557572 ArrayRef<OpFoldResult> ubs,
558573 ArrayRef<OpFoldResult> subShapeSizes,
559- bool omitPartialTileCheck) {
560- SliceParameters sliceParams =
561- computeSliceParameters (builder, loc, valueToTile, tileSizes, map, lbs,
562- ubs, subShapeSizes, omitPartialTileCheck);
574+ bool omitPartialTileCheck,
575+ ArrayRef<int64_t > domainSizes) {
576+ SliceParameters sliceParams = computeSliceParameters (
577+ builder, loc, valueToTile, tileSizes, map, lbs, ubs, subShapeSizes,
578+ omitPartialTileCheck, domainSizes);
563579 return materializeTiledShape (builder, loc, valueToTile, sliceParams);
564580}
565581
566- SliceParameters
567- computeSliceParameters (OpBuilder &builder, Location loc, Value valueToTile,
568- ArrayRef<OpFoldResult> tileSizes, AffineMap map,
569- ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
570- ArrayRef<OpFoldResult> subShapeSizes,
571- bool omitPartialTileCheck) {
582+ SliceParameters computeSliceParameters (
583+ OpBuilder &builder, Location loc, Value valueToTile,
584+ ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
585+ ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
586+ bool omitPartialTileCheck, ArrayRef<int64_t > domainSizes) {
572587 auto shapedType = dyn_cast<ShapedType>(valueToTile.getType ());
573588 assert (shapedType && " only shaped types can be tiled" );
574589 ArrayRef<int64_t > shape = shapedType.getShape ();
@@ -585,7 +600,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
585600 // The offset & size computation below only handles the case when
586601 // the map is monotonically increasing, i.e. the min and max values are
587602 // attained at the lower and upper bounds of the iteration domain.
588- if (!isTiled (m, tileSizes) || !m. isComponentWiseMonotonicallyIncreasing ( )) {
603+ if (!isTiled (m, tileSizes, domainSizes )) {
589604 sliceParams.offsets .push_back (builder.getIndexAttr (0 ));
590605 OpFoldResult dim = createFoldedDimOp (builder, loc, valueToTile, r);
591606 sliceParams.sizes .push_back (dim);
@@ -786,8 +801,9 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
786801 // subdomains explicit.
787802
788803 Type operandType = opOperand.get ().getType ();
789- if (!isTiled (map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
790- linalgOp.isDpsInit (&opOperand))) {
804+ if (!isTiled (map, tileSizes, linalgOp.getStaticLoopRanges ()) &&
805+ !(isa<RankedTensorType>(operandType) &&
806+ linalgOp.isDpsInit (&opOperand))) {
791807 allSliceParams.push_back (std::nullopt );
792808 LLVM_DEBUG (llvm::dbgs ()
793809 << " : not tiled: use shape: " << operandType << " \n " );
@@ -797,7 +813,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
797813
798814 allSliceParams.push_back (computeSliceParameters (
799815 builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
800- omitPartialTileCheck));
816+ omitPartialTileCheck, linalgOp. getStaticLoopRanges () ));
801817 }
802818
803819 return allSliceParams;
0 commit comments