Skip to content

Commit dd9d616

Browse files
[Flow] Add patterns to convert from tensor.concat to flow.tensor.update. (iree-org#19126)
These are in preparation to delay the decomposition of `tensor.concat` into `tensor.insert_slice`s. This patch just adds the patterns to lower a `tensor.concat` along the outer dimension to `flow.tensor.update`. Future changes will delay the decomposition of `tensor.concat` to allow for non-outer dimension concatenation to be converted into `tensor.insert_slice`s before dispatch formation with the `tensor.insert_slice` fused into its producers. Towards iree-org#19092 --------- Signed-off-by: MaheshRavishankar <[email protected]>
1 parent ef241f9 commit dd9d616

File tree

6 files changed

+117
-8
lines changed

6 files changed

+117
-8
lines changed

compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ iree_compiler_cc_library(
2525
deps = [
2626
"//compiler/src/iree/compiler/Dialect/Flow/IR",
2727
"@llvm-project//llvm:Support",
28+
"@llvm-project//mlir:AffineDialect",
2829
"@llvm-project//mlir:Analysis",
2930
"@llvm-project//mlir:ArithDialect",
3031
"@llvm-project//mlir:ArithUtils",

compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ iree_cc_library(
2121
"Utils.cpp"
2222
DEPS
2323
LLVMSupport
24+
MLIRAffineDialect
2425
MLIRAnalysis
2526
MLIRArithDialect
2627
MLIRArithUtils

compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.h"
1010
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
1111
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
12+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1213
#include "mlir/Dialect/Arith/IR/Arith.h"
1314
#include "mlir/Dialect/Arith/Utils/Utils.h"
1415
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -174,6 +175,74 @@ struct ConvertTensorCastPattern : public OpRewritePattern<tensor::CastOp> {
174175
}
175176
};
176177

178+
struct ConvertTensorConcatPattern : public OpRewritePattern<tensor::ConcatOp> {
179+
using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;
180+
181+
LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
182+
PatternRewriter &rewriter) const override {
183+
if (concatOp->getParentOfType<IREE::Flow::DispatchRegionOp>() ||
184+
concatOp->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
185+
return failure();
186+
}
187+
if (concatOp.getDim() != 0) {
188+
return rewriter.notifyMatchFailure(
189+
concatOp, "only outer-dim concat lowering supported");
190+
}
191+
assert(cast<RankedTensorType>(concatOp.getInputs().front().getType())
192+
.getRank() != 0 &&
193+
"concat cannot be of zero-rank tensors");
194+
195+
Location loc = concatOp.getLoc();
196+
SmallVector<SmallVector<OpFoldResult>> inputShapes;
197+
inputShapes.reserve(concatOp.getInputs().size());
198+
// Note the output shape is computed directly without using
199+
// `reifyResultShapes` since we need the `inputShapes` anyway and using the
200+
// method would create duplicate `tensor.dim` operations.
201+
SmallVector<OpFoldResult> outputShape;
202+
AffineExpr addExpr =
203+
rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1);
204+
SmallVector<OpFoldResult> concatOffsets;
205+
concatOffsets.reserve(concatOp.getInputs().size());
206+
for (auto [index, input] : llvm::enumerate(concatOp.getInputs())) {
207+
SmallVector<OpFoldResult> inputShape =
208+
tensor::getMixedSizes(rewriter, input.getLoc(), input);
209+
if (index == 0) {
210+
outputShape = inputShape;
211+
concatOffsets.push_back(rewriter.getIndexAttr(0));
212+
} else {
213+
concatOffsets.push_back(outputShape[0]);
214+
outputShape[0] = affine::makeComposedFoldedAffineApply(
215+
rewriter, loc, addExpr, {outputShape[0], inputShape[0]});
216+
}
217+
inputShapes.emplace_back(std::move(inputShape));
218+
}
219+
220+
Value replacement = rewriter.create<tensor::EmptyOp>(
221+
loc, outputShape, concatOp.getType().getElementType());
222+
223+
SmallVector<int64_t> resultStaticDims;
224+
SmallVector<Value> resultDynamicDims;
225+
dispatchIndexOpFoldResults(outputShape, resultDynamicDims,
226+
resultStaticDims);
227+
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
228+
// Generate the `flow.tensor.update` operations for the concat.
229+
for (auto [index, input] : llvm::enumerate(concatOp.getInputs())) {
230+
SmallVector<int64_t> inputStaticShape;
231+
SmallVector<Value> inputDynamicShape;
232+
dispatchIndexOpFoldResults(inputShapes[index], inputDynamicShape,
233+
inputStaticShape);
234+
SmallVector<Value> offsets(inputStaticShape.size(), zero);
235+
offsets[0] =
236+
getValueOrCreateConstantIndexOp(rewriter, loc, concatOffsets[index]);
237+
replacement = rewriter.create<IREE::Flow::TensorUpdateOp>(
238+
loc, replacement.getType(), replacement, resultDynamicDims, offsets,
239+
input, inputDynamicShape);
240+
}
241+
rewriter.replaceOp(concatOp, replacement);
242+
return success();
243+
}
244+
};
245+
177246
struct ConvertTensorFromElementsPattern
178247
: public OpRewritePattern<tensor::FromElementsOp> {
179248
using OpRewritePattern<tensor::FromElementsOp>::OpRewritePattern;
@@ -316,14 +385,14 @@ struct ConvertTensorReshapePattern : public OpRewritePattern<TensorReshapeOp> {
316385

317386
void populateTensorToFlowConversionPatterns(MLIRContext *context,
318387
RewritePatternSet &patterns) {
319-
patterns
320-
.insert<ConvertLinalgFillPattern, ConvertTensorBitcastPattern,
321-
ConvertTensorCastPattern, ConvertTensorExtractPattern,
322-
ConvertTensorExtractSlicePattern, ConvertTensorInsertSlicePattern,
323-
ConvertTensorInsertPattern, ConvertTensorFromElementsPattern,
324-
ConvertTensorDialectReshapeOpPattern,
325-
ConvertTensorReshapePattern<tensor::CollapseShapeOp>,
326-
ConvertTensorReshapePattern<tensor::ExpandShapeOp>>(context);
388+
patterns.insert<ConvertLinalgFillPattern, ConvertTensorBitcastPattern,
389+
ConvertTensorCastPattern, ConvertTensorConcatPattern,
390+
ConvertTensorExtractPattern, ConvertTensorExtractSlicePattern,
391+
ConvertTensorInsertSlicePattern, ConvertTensorInsertPattern,
392+
ConvertTensorFromElementsPattern,
393+
ConvertTensorDialectReshapeOpPattern,
394+
ConvertTensorReshapePattern<tensor::CollapseShapeOp>,
395+
ConvertTensorReshapePattern<tensor::ExpandShapeOp>>(context);
327396
}
328397

329398
} // namespace mlir::iree_compiler::IREE::Flow

compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ iree_lit_test_suite(
1818
[
1919
"bitcast.mlir",
2020
"cast.mlir",
21+
"concat.mlir",
2122
"extract.mlir",
2223
"extract_slice.mlir",
2324
"fill.mlir",

compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ iree_lit_test_suite(
1616
SRCS
1717
"bitcast.mlir"
1818
"cast.mlir"
19+
"concat.mlir"
1920
"extract.mlir"
2021
"extract_slice.mlir"
2122
"fill.mlir"
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: iree-opt --iree-flow-convert-to-flow --split-input-file --mlir-print-local-scope %s | FileCheck %s
2+
3+
func.func @mixed_concat(%arg0: tensor<2x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<4x?xf32>) -> tensor<?x?xf32> {
4+
%0 = tensor.concat dim(0) %arg0, %arg1, %arg2 : (tensor<2x?xf32>, tensor<?x?xf32>, tensor<4x?xf32>) -> tensor<?x?xf32>
5+
return %0 : tensor<?x?xf32>
6+
}
7+
// CHECK-LABEL: func @mixed_concat
8+
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x?xf32>
9+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
10+
// CHECK-SAME: %[[ARG2:.+]]: tensor<4x?xf32>
11+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
12+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
13+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
14+
// CHECK-DAG: %[[ARG0_D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
15+
// CHECK-DAG: %[[ARG1_D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
16+
// CHECK-DAG: %[[ARG1_D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
17+
// CHECK: %[[OFFSET0:.+]] = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%[[ARG1_D0]]]
18+
// CHECK: %[[ARG2_D1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
19+
// CHECK: %[[RESULT_D0:.+]] = affine.apply affine_map<()[s0] -> (s0 + 6)>()[%[[ARG1_D0]]]
20+
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[RESULT_D0]], %[[ARG0_D1]])
21+
// CHECK: %[[UPDATE0:.+]] = flow.tensor.update %[[ARG0]], %[[EMPTY]][%[[C0]], %[[C0]]]
22+
// CHECK-SAME: : tensor<2x?xf32>{%[[ARG0_D1]]} -> %[[EMPTY]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}
23+
// CHECK: %[[UPDATE1:.+]] = flow.tensor.update %[[ARG1]], %[[UPDATE0]][%[[C2]], %[[C0]]]
24+
// CHECK-SAME: : tensor<?x?xf32>{%[[ARG1_D0]], %[[ARG1_D1]]} -> %[[UPDATE0]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}
25+
// CHECK: %[[UPDATE2:.+]] = flow.tensor.update %[[ARG2]], %[[UPDATE1]][%[[OFFSET0]], %[[C0]]]
26+
// CHECK-SAME: : tensor<4x?xf32>{%[[ARG2_D1]]} -> %[[UPDATE1]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}
27+
28+
// -----
29+
30+
func.func @dont_lower_non_outer_dim_concat(%arg0: tensor<4x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<4x?xf32>) -> tensor<?x?xf32> {
31+
%0 = tensor.concat dim(1) %arg0, %arg1, %arg2 : (tensor<4x?xf32>, tensor<?x?xf32>, tensor<4x?xf32>) -> tensor<?x?xf32>
32+
return %0 : tensor<?x?xf32>
33+
}
34+
// CHECK-LABEL: func @dont_lower_non_outer_dim_concat
35+
// CHECK: %[[CONCAT:.+]] = tensor.concat
36+
// CHECK: return %[[CONCAT]]

0 commit comments

Comments
 (0)