Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1996,13 +1996,17 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
candidateSliceOp, "containingOp's result yield with stride");
}

// 10. Try to get iter domain position from input position.
// 10. Try to get iter domain position from input position. Use
// clonedConsumerOp instead of tiledConsumerOp, because the iteration domain
// may require index computation based on the result size. The sizes and
// offsets should be the same either way, but using tiledConsumerOp could
// lead to some chained unnecessary extra index computation.
SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
iterDomainSizes))) {
return rewriter.notifyMatchFailure(
tiledConsumerOp,
clonedConsumerOp,
"can't get iter domain position from input position");
}

Expand Down
24 changes: 19 additions & 5 deletions mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

Expand Down Expand Up @@ -621,6 +622,12 @@ struct UnPackOpTiling
SmallVectorImpl<OpFoldResult> &resultOffsets,
SmallVectorImpl<OpFoldResult> &resultSizes) const {
auto unPackOp = cast<UnPackOp>(op);
// If the operand tile is the dest, then no adjustment is needed.
if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
resultOffsets = llvm::to_vector(offsets);
resultSizes = llvm::to_vector(sizes);
return success();
}
Location loc = unPackOp.getLoc();

int64_t numTiles = unPackOp.getInnerDimsPos().size();
Expand All @@ -629,6 +636,10 @@ struct UnPackOpTiling
// The tiling is applied on interchanged dimensions. We have to undo the
// interchange to map sizes and offsets to the original input.
int64_t outputRank = unPackOp.getDestRank();
ReifiedRankedShapedTypeDims reifiedReturnShapes;
if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
return failure();
SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
SmallVector<OpFoldResult> origOffsets(destOffsets);
SmallVector<OpFoldResult> origSizes(destSizes);
applyPermToRange(origOffsets, origSizes,
Expand All @@ -640,18 +651,21 @@ struct UnPackOpTiling
for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
using AV = affine::AffineValueExpr;
affine::AffineBuilder ab(b, loc);
AffineExpr dim0, dim1, sym;
AffineExpr dim0, dim1, sym0;
bindDims(b.getContext(), dim0, dim1);
bindSymbols(b.getContext(), sym);
bindSymbols(b.getContext(), sym0);
if (dimAndTileMapping.count(dim)) {
// If the data dimension is tiled, the i-th index is the product of
// offset_i and tile_i, and the i-th size is the product of sizes_i and
// tile_i.
// tile_i. The sizes must be clamped to the sizes of the unpack result.
auto avOffset = AV(dim0).bind(origOffsets[dim]);
auto avSize = AV(dim0).bind(origSizes[dim]);
auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
resultOffsets.push_back(ab.mul(avOffset, avTileSize));
resultSizes.push_back(ab.mul(avSize, avTileSize));
auto avResultOffset = AV(dim1).bind(resultOffsets.back());
resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
ab.sub(avResultSize, avResultOffset)}));
} else {
resultOffsets.push_back(origOffsets[dim]);
resultSizes.push_back(origSizes[dim]);
Expand Down
75 changes: 69 additions & 6 deletions mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ module {
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
%1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
%extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
^bb0(%in: f32, %in_16: f32, %out: f32):
Expand All @@ -292,26 +292,89 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
// CHECK: #[[UNPACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
// CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32>
// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
// CHECK-SAME: {
// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
// Test input: an scf.forall producer whose result is consumed by a
// tensor.unpack with an unaligned destination. The 64x32 source unpacked with
// inner_tiles = [32] spans 64 * 32 = 2048 elements, but the destination tensor
// only holds 2047, so the last tile does not fit the destination exactly.
module {
func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
// NOTE(review): %c4, %c64 and %c0 are not referenced anywhere in this function.
%c4 = arith.constant 4 : index
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
// Producer loop: iterates from (0, 0) to (64, 32) in steps of (32, 32); each
// iteration works on one 32x32 slice of the shared output %arg5.
%1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
%extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
// Elementwise multiply of the two inputs, accumulated into the output slice.
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
^bb0(%in: f32, %in_16: f32, %out: f32):
%13 = arith.mulf %in, %in_16 : f32
%14 = arith.addf %out, %13 : f32
linalg.yield %14 : f32
} -> tensor<32x32xf32>
// Write the computed tile back at the same offsets it was extracted from.
scf.forall.in_parallel {
tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
}
}
// Consumer of the loop result: unpack the 64x32 tensor into the 2047-element
// destination (one element short of 64 * 32).
%output = tensor.empty() : tensor<2047xf32>
%unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
return %unpack : tensor<2047xf32>
}
}

// Transform script: matches the tensor.parallel_insert_slice ops in the
// payload and passes the resulting handle to transform.test.fuse_consumer.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
// Handle to the ops named "tensor.parallel_insert_slice" found in %arg1.
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
: (!transform.any_op) -> !transform.any_op
// NOTE(review): presumably fuses the consumer of the matched slice into the
// enclosing loop -- confirm against the test transform op's definition. The
// two result handles (%a, %b) are not used afterwards.
%a, %b = transform.test.fuse_consumer %slice_op
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
// CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32>
// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
// CHECK-SAME: {
// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
// CHECK: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_MAP]](%[[IV1]])
// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
// CHECK: scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
// CHECK: }
// CHECK: }
// CHECK: return %[[FINAL_RESULT]]#1 :
Expand Down
Loading