Skip to content

Commit ca3622a

Browse files
[Flow] Add patterns to convert from tensor.concat to flow.tensor.update.
These are in preparation to delay to decomposition of `tensor.concat` into `tensor.insert_slice`s. This patch just adds the patterns to lower a `tensor.concat` along the outer dimension to `flow.tensor.update`. Future changes will delay the decomposition of `tensor.concat` to allow for non-outer dimension concatenation to be conveted into `tensor.insert_slice`s before dispatch formation with the `tensor.insert_slice` fused into its producers. Towards iree-org#19092 Signed-off-by: MaheshRavishankar <[email protected]>
1 parent ab35e1b commit ca3622a

File tree

4 files changed

+106
-8
lines changed

4 files changed

+106
-8
lines changed

compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Patterns.cpp

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.h"
1010
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
1111
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
12+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1213
#include "mlir/Dialect/Arith/IR/Arith.h"
1314
#include "mlir/Dialect/Arith/Utils/Utils.h"
1415
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -174,6 +175,75 @@ struct ConvertTensorCastPattern : public OpRewritePattern<tensor::CastOp> {
174175
}
175176
};
176177

178+
struct ConvertTensorConcatPattern : public OpRewritePattern<tensor::ConcatOp> {
179+
using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;
180+
181+
LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
182+
PatternRewriter &rewriter) const override {
183+
if (concatOp->getParentOfType<IREE::Flow::DispatchRegionOp>() ||
184+
concatOp->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
185+
return failure();
186+
}
187+
if (concatOp.getDim() != 0) {
188+
return rewriter.notifyMatchFailure(
189+
concatOp, "only outer-dim concat lowering supported");
190+
}
191+
if (cast<RankedTensorType>(concatOp.getInputs().front().getType())
192+
.getRank() == 0) {
193+
// This should be handled here, but not sure what concat operation does
194+
// when inptus are of rank 0.
195+
return rewriter.notifyMatchFailure(
196+
concatOp, "unhandled concat of zero-rank tensors");
197+
}
198+
199+
Location loc = concatOp.getLoc();
200+
SmallVector<SmallVector<OpFoldResult>> inputShapes;
201+
inputShapes.reserve(concatOp.getInputs().size());
202+
SmallVector<OpFoldResult> outputShape;
203+
AffineExpr addExpr =
204+
rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1);
205+
SmallVector<OpFoldResult> concatOffsets;
206+
concatOffsets.reserve(concatOp.getInputs().size());
207+
for (auto [index, input] : llvm::enumerate(concatOp.getInputs())) {
208+
SmallVector<OpFoldResult> inputShape =
209+
tensor::getMixedSizes(rewriter, input.getLoc(), input);
210+
if (index == 0) {
211+
outputShape = inputShape;
212+
concatOffsets.push_back(rewriter.getIndexAttr(0));
213+
} else {
214+
concatOffsets.push_back(outputShape[0]);
215+
outputShape[0] = affine::makeComposedFoldedAffineApply(
216+
rewriter, loc, addExpr, {outputShape[0], inputShape[0]});
217+
}
218+
inputShapes.emplace_back(std::move(inputShape));
219+
}
220+
221+
Value replacement = rewriter.create<tensor::EmptyOp>(
222+
loc, outputShape, concatOp.getType().getElementType());
223+
224+
SmallVector<int64_t> resultStaticDims;
225+
SmallVector<Value> resultDynamicDims;
226+
dispatchIndexOpFoldResults(outputShape, resultDynamicDims,
227+
resultStaticDims);
228+
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
229+
// Generate the `flow.tensor.update` operations for the concat.
230+
for (auto [index, input] : llvm::enumerate(concatOp.getInputs())) {
231+
SmallVector<int64_t> inputStaticShape;
232+
SmallVector<Value> inputDynamicShape;
233+
dispatchIndexOpFoldResults(inputShapes[index], inputDynamicShape,
234+
inputStaticShape);
235+
SmallVector<Value> offsets(inputStaticShape.size(), zero);
236+
offsets[0] =
237+
getValueOrCreateConstantIndexOp(rewriter, loc, concatOffsets[index]);
238+
replacement = rewriter.create<IREE::Flow::TensorUpdateOp>(
239+
loc, replacement.getType(), replacement, resultDynamicDims, offsets,
240+
input, inputDynamicShape);
241+
}
242+
rewriter.replaceOp(concatOp, replacement);
243+
return success();
244+
}
245+
};
246+
177247
struct ConvertTensorFromElementsPattern
178248
: public OpRewritePattern<tensor::FromElementsOp> {
179249
using OpRewritePattern<tensor::FromElementsOp>::OpRewritePattern;
@@ -316,14 +386,14 @@ struct ConvertTensorReshapePattern : public OpRewritePattern<TensorReshapeOp> {
316386

317387
void populateTensorToFlowConversionPatterns(MLIRContext *context,
318388
RewritePatternSet &patterns) {
319-
patterns
320-
.insert<ConvertLinalgFillPattern, ConvertTensorBitcastPattern,
321-
ConvertTensorCastPattern, ConvertTensorExtractPattern,
322-
ConvertTensorExtractSlicePattern, ConvertTensorInsertSlicePattern,
323-
ConvertTensorInsertPattern, ConvertTensorFromElementsPattern,
324-
ConvertTensorDialectReshapeOpPattern,
325-
ConvertTensorReshapePattern<tensor::CollapseShapeOp>,
326-
ConvertTensorReshapePattern<tensor::ExpandShapeOp>>(context);
389+
patterns.insert<ConvertLinalgFillPattern, ConvertTensorBitcastPattern,
390+
ConvertTensorCastPattern, ConvertTensorConcatPattern,
391+
ConvertTensorExtractPattern, ConvertTensorExtractSlicePattern,
392+
ConvertTensorInsertSlicePattern, ConvertTensorInsertPattern,
393+
ConvertTensorFromElementsPattern,
394+
ConvertTensorDialectReshapeOpPattern,
395+
ConvertTensorReshapePattern<tensor::CollapseShapeOp>,
396+
ConvertTensorReshapePattern<tensor::ExpandShapeOp>>(context);
327397
}
328398

329399
} // namespace mlir::iree_compiler::IREE::Flow

compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ iree_lit_test_suite(
1818
[
1919
"bitcast.mlir",
2020
"cast.mlir",
21+
"concat.mlir",
2122
"extract.mlir",
2223
"extract_slice.mlir",
2324
"fill.mlir",

compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ iree_lit_test_suite(
1616
SRCS
1717
"bitcast.mlir"
1818
"cast.mlir"
19+
"concat.mlir"
1920
"extract.mlir"
2021
"extract_slice.mlir"
2122
"fill.mlir"
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: iree-opt --iree-flow-convert-to-flow --split-input-file --mlir-print-local-scope %s | FileCheck %s

func.func @mixed_concat(%arg0: tensor<2x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<4x?xf32>) -> tensor<?x?xf32> {
  %0 = tensor.concat dim(0) %arg0, %arg1, %arg2 : (tensor<2x?xf32>, tensor<?x?xf32>, tensor<4x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @mixed_concat
//  CHECK-SAME:     %[[ARG0:.+]]: tensor<2x?xf32>
//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
//  CHECK-SAME:     %[[ARG2:.+]]: tensor<4x?xf32>
//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//   CHECK-DAG:   %[[ARG0_D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
//   CHECK-DAG:   %[[ARG1_D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
//   CHECK-DAG:   %[[ARG1_D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
//       CHECK:   %[[OFFSET0:.+]] = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%[[ARG1_D0]]]
//       CHECK:   %[[ARG2_D1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
//       CHECK:   %[[RESULT_D0:.+]] = affine.apply affine_map<()[s0] -> (s0 + 6)>()[%[[ARG1_D0]]]
//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[RESULT_D0]], %[[ARG0_D1]])
//       CHECK:   %[[UPDATE0:.+]] = flow.tensor.update %[[ARG0]], %[[EMPTY]][%[[C0]], %[[C0]]]
//  CHECK-SAME:     : tensor<2x?xf32>{%[[ARG0_D1]]} -> %[[EMPTY]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}
//       CHECK:   %[[UPDATE1:.+]] = flow.tensor.update %[[ARG1]], %[[UPDATE0]][%[[C2]], %[[C0]]]
//  CHECK-SAME:     : tensor<?x?xf32>{%[[ARG1_D0]], %[[ARG1_D1]]} -> %[[UPDATE0]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}
//       CHECK:   %[[UPDATE2:.+]] = flow.tensor.update %[[ARG2]], %[[UPDATE1]][%[[OFFSET0]], %[[C0]]]
//  CHECK-SAME:     : tensor<4x?xf32>{%[[ARG2_D1]]} -> %[[UPDATE1]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}

0 commit comments

Comments
 (0)