Skip to content

Commit 78da33c

Browse files
committed
Use sizeBounds instead of domainSizes
1 parent 6d0b806 commit 78da33c

File tree

6 files changed

+431
-31
lines changed

6 files changed

+431
-31
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,12 @@ struct SliceParameters {
143143
///
144144
/// `omitPartialTileCheck` controls whether to omit the partial/boundary tile
145145
/// condition check in cases where we statically know that it is unnecessary.
146-
SliceParameters computeSliceParameters(
147-
OpBuilder &builder, Location loc, Value valueToTile,
148-
ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
149-
ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
150-
bool omitPartialTileCheck, ArrayRef<int64_t> domainSizes = {});
146+
SliceParameters
147+
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
148+
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
149+
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
150+
ArrayRef<OpFoldResult> subShapeSizes,
151+
bool omitPartialTileCheck);
151152

152153
/// Computes SliceParameters for all `valuesToTile` of the given `linalgOp`,
153154
/// assuming `linalgOp` is being fused into a loop nest. Calls

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,16 @@ struct LinalgOpTilingInterface
115115
getTiledImplementation(Operation *op, OpBuilder &b,
116116
ArrayRef<OpFoldResult> offsets,
117117
ArrayRef<OpFoldResult> sizes) const {
118-
// Leave the `sizeBounds` value empty. That is only needed when the `sizes`
119-
// specified could lead to out of bounds accesses.
120118
Location loc = op->getLoc();
121119
LinalgOp linalgOp = cast<LinalgOp>(op);
120+
SmallVector<OpFoldResult> allShapeSizes =
121+
linalgOp.createFlatListOfOperandDims(b, linalgOp.getLoc());
122+
SmallVector<OpFoldResult> sizeBounds =
123+
mlir::affine::makeComposedFoldedMultiResultAffineApply(
124+
b, loc, linalgOp.getShapesToLoopsMap(), allShapeSizes);
122125
SmallVector<Value> valuesToTile = linalgOp->getOperands();
123126
SmallVector<Value> tiledOperands = makeTiledShapes(
124-
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
127+
b, loc, linalgOp, valuesToTile, offsets, sizes, sizeBounds, true);
125128
SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
126129
llvm::make_filter_range(
127130
tiledOperands,

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,17 +56,18 @@ namespace {
5656
// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
5757
//
5858
struct TileCheck : public AffineExprVisitor<TileCheck> {
59-
TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<int64_t> domainSizes)
60-
: tileSizes(tileSizes), domainSizes(domainSizes) {}
59+
TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds)
60+
: tileSizes(tileSizes), sizeBounds(sizeBounds) {}
6161

6262
void visitDimExpr(AffineDimExpr expr) {
6363
unsigned pos = expr.getPosition();
6464

6565
// This dimension is tiled if the tile size is larger than zero and not
6666
// equal to its size bound (if statically known).
6767
std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
68-
if (tileSize && !domainSizes.empty()) {
69-
if (domainSizes[pos] == *tileSize) {
68+
if (tileSize && !sizeBounds.empty()) {
69+
std::optional<int64_t> sizeBound = getConstantIntValue(sizeBounds[pos]);
70+
if (sizeBound && *sizeBound == *tileSize) {
7071
return;
7172
}
7273
}
@@ -82,27 +83,27 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
8283
}
8384
bool isTiled = false;
8485
ArrayRef<OpFoldResult> tileSizes;
85-
ArrayRef<int64_t> domainSizes;
86+
ArrayRef<OpFoldResult> sizeBounds;
8687
};
8788

8889
} // namespace
8990

9091
static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
91-
ArrayRef<int64_t> domainSizes) {
92+
ArrayRef<OpFoldResult> sizeBounds) {
9293
if (!expr)
9394
return false;
94-
TileCheck t(tileSizes, domainSizes);
95+
TileCheck t(tileSizes, sizeBounds);
9596
t.visit(expr);
9697
return t.isTiled;
9798
}
9899

99100
// Checks whether the `map` varies with respect to a non-zero `tileSize`.
100101
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes,
101-
ArrayRef<int64_t> domainSizes) {
102+
ArrayRef<OpFoldResult> sizeBounds) {
102103
if (!map)
103104
return false;
104105
for (unsigned r = 0; r < map.getNumResults(); ++r)
105-
if (isTiled(map.getResult(r), tileSizes, domainSizes))
106+
if (isTiled(map.getResult(r), tileSizes, sizeBounds))
106107
return true;
107108
return false;
108109
}
@@ -571,19 +572,19 @@ Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
571572
ArrayRef<OpFoldResult> lbs,
572573
ArrayRef<OpFoldResult> ubs,
573574
ArrayRef<OpFoldResult> subShapeSizes,
574-
bool omitPartialTileCheck,
575-
ArrayRef<int64_t> domainSizes) {
576-
SliceParameters sliceParams = computeSliceParameters(
577-
builder, loc, valueToTile, tileSizes, map, lbs, ubs, subShapeSizes,
578-
omitPartialTileCheck, domainSizes);
575+
bool omitPartialTileCheck) {
576+
SliceParameters sliceParams =
577+
computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
578+
ubs, subShapeSizes, omitPartialTileCheck);
579579
return materializeTiledShape(builder, loc, valueToTile, sliceParams);
580580
}
581581

582-
SliceParameters computeSliceParameters(
583-
OpBuilder &builder, Location loc, Value valueToTile,
584-
ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
585-
ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
586-
bool omitPartialTileCheck, ArrayRef<int64_t> domainSizes) {
582+
SliceParameters
583+
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
584+
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
585+
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
586+
ArrayRef<OpFoldResult> subShapeSizes,
587+
bool omitPartialTileCheck) {
587588
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
588589
assert(shapedType && "only shaped types can be tiled");
589590
ArrayRef<int64_t> shape = shapedType.getShape();
@@ -600,7 +601,7 @@ SliceParameters computeSliceParameters(
600601
// The offset & size computation below only handles the case when
601602
// the map is monotonically increasing, i.e. the min and max values are
602603
// attained at the lower and upper bounds of the iteration domain.
603-
if (!isTiled(m, tileSizes, domainSizes)) {
604+
if (!isTiled(m, tileSizes, ubs)) {
604605
sliceParams.offsets.push_back(builder.getIndexAttr(0));
605606
OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
606607
sliceParams.sizes.push_back(dim);
@@ -811,7 +812,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
811812

812813
allSliceParams.push_back(computeSliceParameters(
813814
builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
814-
omitPartialTileCheck, linalgOp.getStaticLoopRanges()));
815+
omitPartialTileCheck));
815816
}
816817

817818
return allSliceParams;

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,8 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
984984

985985
// 4a. Clone the operation.
986986
{
987-
auto clonedOp = cast<PartialReductionOpInterface>(rewriter.clone(*op));
987+
auto clonedOp = cast<PartialReductionOpInterface>(
988+
cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
988989

989990
// 4b. Tile the cloned operation.
990991
FailureOr<TilingResult> partialTilingResult =

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,11 +555,12 @@ module {
555555

556556
// CHECK: %[[T1:.*]] = linalg.generic {{.*}}
557557
// CHECK: %[[T2:.*]] = linalg.generic {{.*}}
558+
// CHECK: %[[T3:.*]] = linalg.generic {{.*}}
558559
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
559560

560561
%8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
561562
scf.forall.in_parallel {
562-
// CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
563+
// CHECK: tensor.parallel_insert_slice %[[T3]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
563564
tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor<?xf32> into tensor<?xf32>
564565
}
565566
}

0 commit comments

Comments
 (0)