@@ -56,17 +56,18 @@ namespace {
5656// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
5757//
5858struct TileCheck : public AffineExprVisitor <TileCheck> {
59- TileCheck (ArrayRef<OpFoldResult> tileSizes, ArrayRef<int64_t > domainSizes )
60- : tileSizes(tileSizes), domainSizes(domainSizes ) {}
59+ TileCheck (ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds )
60+ : tileSizes(tileSizes), sizeBounds(sizeBounds ) {}
6161
6262 void visitDimExpr (AffineDimExpr expr) {
6363 unsigned pos = expr.getPosition ();
6464
6565 // This dimension is tiled if the tile size is larger than zero and not
6666 // equal to its domain size (if statically known).
6767 std::optional<int64_t > tileSize = getConstantIntValue (tileSizes[pos]);
68- if (tileSize && !domainSizes.empty ()) {
69- if (domainSizes[pos] == *tileSize) {
68+ if (tileSize && !sizeBounds.empty ()) {
69+ std::optional<int64_t > sizeBound = getConstantIntValue (sizeBounds[pos]);
70+ if (sizeBound && *sizeBound == *tileSize) {
7071 return ;
7172 }
7273 }
@@ -82,27 +83,27 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
8283 }
8384 bool isTiled = false ;
8485 ArrayRef<OpFoldResult> tileSizes;
85- ArrayRef<int64_t > domainSizes ;
86+ ArrayRef<OpFoldResult> sizeBounds ;
8687};
8788
8889} // namespace
8990
9091static bool isTiled (AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
91- ArrayRef<int64_t > domainSizes ) {
92+ ArrayRef<OpFoldResult> sizeBounds ) {
9293 if (!expr)
9394 return false ;
94- TileCheck t (tileSizes, domainSizes );
95+ TileCheck t (tileSizes, sizeBounds );
9596 t.visit (expr);
9697 return t.isTiled ;
9798}
9899
99100// Checks whether the `map varies with respect to a non-zero `tileSize`.
100101static bool isTiled (AffineMap map, ArrayRef<OpFoldResult> tileSizes,
101- ArrayRef<int64_t > domainSizes ) {
102+ ArrayRef<OpFoldResult> sizeBounds ) {
102103 if (!map)
103104 return false ;
104105 for (unsigned r = 0 ; r < map.getNumResults (); ++r)
105- if (isTiled (map.getResult (r), tileSizes, domainSizes ))
106+ if (isTiled (map.getResult (r), tileSizes, sizeBounds ))
106107 return true ;
107108 return false ;
108109}
@@ -571,19 +572,19 @@ Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
571572 ArrayRef<OpFoldResult> lbs,
572573 ArrayRef<OpFoldResult> ubs,
573574 ArrayRef<OpFoldResult> subShapeSizes,
574- bool omitPartialTileCheck,
575- ArrayRef<int64_t > domainSizes) {
576- SliceParameters sliceParams = computeSliceParameters (
577- builder, loc, valueToTile, tileSizes, map, lbs, ubs, subShapeSizes,
578- omitPartialTileCheck, domainSizes);
575+ bool omitPartialTileCheck) {
576+ SliceParameters sliceParams =
577+ computeSliceParameters (builder, loc, valueToTile, tileSizes, map, lbs,
578+ ubs, subShapeSizes, omitPartialTileCheck);
579579 return materializeTiledShape (builder, loc, valueToTile, sliceParams);
580580}
581581
582- SliceParameters computeSliceParameters (
583- OpBuilder &builder, Location loc, Value valueToTile,
584- ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
585- ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
586- bool omitPartialTileCheck, ArrayRef<int64_t > domainSizes) {
582+ SliceParameters
583+ computeSliceParameters (OpBuilder &builder, Location loc, Value valueToTile,
584+ ArrayRef<OpFoldResult> tileSizes, AffineMap map,
585+ ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
586+ ArrayRef<OpFoldResult> subShapeSizes,
587+ bool omitPartialTileCheck) {
587588 auto shapedType = dyn_cast<ShapedType>(valueToTile.getType ());
588589 assert (shapedType && " only shaped types can be tiled" );
589590 ArrayRef<int64_t > shape = shapedType.getShape ();
@@ -600,7 +601,7 @@ SliceParameters computeSliceParameters(
600601 // The offset & size computation below only handles the case when
601602 // the map is monotonically increasing, i.e. the min and max values are
602603 // attained at the lower and upper bounds of the iteration domain.
603- if (!isTiled (m, tileSizes, domainSizes )) {
604+ if (!isTiled (m, tileSizes, ubs )) {
604605 sliceParams.offsets .push_back (builder.getIndexAttr (0 ));
605606 OpFoldResult dim = createFoldedDimOp (builder, loc, valueToTile, r);
606607 sliceParams.sizes .push_back (dim);
@@ -811,7 +812,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
811812
812813 allSliceParams.push_back (computeSliceParameters (
813814 builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
814- omitPartialTileCheck, linalgOp. getStaticLoopRanges () ));
815+ omitPartialTileCheck));
815816 }
816817
817818 return allSliceParams;
0 commit comments