Skip to content

Commit 789768a

Browse files
Max191 (Max Dawkins)
authored and committed
[mlir] Clamp UnPackOp tiling sizes from operand tile
Signed-off-by: Max Dawkins <[email protected]>
1 parent b6bd747 commit 789768a

File tree

3 files changed

+35
-14
lines changed

3 files changed

+35
-14
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1996,13 +1996,17 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
19961996
candidateSliceOp, "containingOp's result yield with stride");
19971997
}
19981998

1999-
// 10. Try to get iter domain position from input position.
1999+
// 10. Try to get iter domain position from input position. Use
2000+
// clonedConsumerOp instead of tiledConsumerOp, because the iteration domain
2001+
// may require index computation based on the result size. The sizes and
2002+
// offsets should be the same either way, but using tiledConsumerOp could
2003+
// lead to some chained unnecessary extra index computation.
20002004
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
2001-
if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
2005+
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
20022006
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
20032007
iterDomainSizes))) {
20042008
return rewriter.notifyMatchFailure(
2005-
tiledConsumerOp,
2009+
clonedConsumerOp,
20062010
"can't get iter domain position from input position");
20072011
}
20082012

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1717
#include "mlir/Dialect/Tensor/Utils/Utils.h"
1818
#include "mlir/Dialect/Utils/IndexingUtils.h"
19+
#include "mlir/Interfaces/InferTypeOpInterface.h"
1920
#include "mlir/Interfaces/TilingInterface.h"
2021
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
2122

@@ -621,6 +622,12 @@ struct UnPackOpTiling
621622
SmallVectorImpl<OpFoldResult> &resultOffsets,
622623
SmallVectorImpl<OpFoldResult> &resultSizes) const {
623624
auto unPackOp = cast<UnPackOp>(op);
625+
// If the operand tile is the dest, then no adjustment is needed.
626+
if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
627+
resultOffsets = llvm::to_vector(offsets);
628+
resultSizes = llvm::to_vector(sizes);
629+
return success();
630+
}
624631
Location loc = unPackOp.getLoc();
625632

626633
int64_t numTiles = unPackOp.getInnerDimsPos().size();
@@ -629,6 +636,11 @@ struct UnPackOpTiling
629636
// The tiling is applied on interchanged dimensions. We have to undo the
630637
// interchange to map sizes and offsets to the original input.
631638
int64_t outputRank = unPackOp.getDestRank();
639+
ReifiedRankedShapedTypeDims reifiedReturnShapes;
640+
if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes))) {
641+
return failure();
642+
}
643+
SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
632644
SmallVector<OpFoldResult> origOffsets(destOffsets);
633645
SmallVector<OpFoldResult> origSizes(destSizes);
634646
applyPermToRange(origOffsets, origSizes,
@@ -640,18 +652,21 @@ struct UnPackOpTiling
640652
for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
641653
using AV = affine::AffineValueExpr;
642654
affine::AffineBuilder ab(b, loc);
643-
AffineExpr dim0, dim1, sym;
655+
AffineExpr dim0, dim1, sym0;
644656
bindDims(b.getContext(), dim0, dim1);
645-
bindSymbols(b.getContext(), sym);
657+
bindSymbols(b.getContext(), sym0);
646658
if (dimAndTileMapping.count(dim)) {
647659
// If the data dimension is tiled, the i-th index is the product of
648660
// offset_i and tile_i, and the i-th size is the product of sizes_i and
649-
// tile_i.
661+
// tile_i. The sizes must be clamped to the sizes of the unpack result.
650662
auto avOffset = AV(dim0).bind(origOffsets[dim]);
651663
auto avSize = AV(dim0).bind(origSizes[dim]);
652-
auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
664+
auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
665+
auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
653666
resultOffsets.push_back(ab.mul(avOffset, avTileSize));
654-
resultSizes.push_back(ab.mul(avSize, avTileSize));
667+
auto avResultOffset = AV(dim1).bind(resultOffsets.back());
668+
resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
669+
ab.sub(avResultSize, avResultOffset)}));
655670
} else {
656671
resultOffsets.push_back(origOffsets[dim]);
657672
resultSizes.push_back(origSizes[dim]);

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ module {
265265
%c4 = arith.constant 4 : index
266266
%c64 = arith.constant 64 : index
267267
%c0 = arith.constant 0 : index
268-
%1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
268+
%1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
269269
%extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
270270
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
271271
^bb0(%in: f32, %in_16: f32, %out: f32):
@@ -292,26 +292,28 @@ module attributes {transform.with_named_sequence} {
292292
transform.yield
293293
}
294294
}
295-
// CHECK: #[[UNPACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
295+
// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
296+
// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
296297
// CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
297298
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
298299
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
299300
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
300301
// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32>
301-
// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
302+
// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
302303
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
303304
// CHECK-SAME: {
304305
// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
305306
// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
306307
// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
307-
// CHECK: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_MAP]](%[[IV1]])
308-
// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
308+
// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
309+
// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
310+
// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
309311
// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
310312
// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
311313
// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
312314
// CHECK: scf.forall.in_parallel {
313315
// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
314-
// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
316+
// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
315317
// CHECK: }
316318
// CHECK: }
317319
// CHECK: return %[[FINAL_RESULT]]#1 :

0 commit comments

Comments (0)