diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2b..b73da8bb6af59 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -178,6 +178,9 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
     int64_t getRank() {
       return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
     }
+
+    // Method to decompose the operation into a sequence of insert_slices.
+    FailureOr<SmallVector<Value>> decomposeOperation(OpBuilder &builder);
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 147120e0e3420..616d4a7d0a0ab 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -615,6 +615,54 @@ LogicalResult ConcatOp::verify() {
   return success();
 }
 
+FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
+  size_t numInputs = getInputs().size();
+  uint64_t concatDim = getDim();
+
+  SmallVector<SmallVector<OpFoldResult>> inputShapes;
+  inputShapes.reserve(numInputs);
+  SmallVector<OpFoldResult> concatOffsets;
+  concatOffsets.reserve(numInputs);
+  SmallVector<OpFoldResult> outputShape;
+
+  AffineExpr addExpr =
+      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
+  OpFoldResult zero = builder.getIndexAttr(0);
+  Location loc = getLoc();
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    SmallVector<OpFoldResult> inputShape =
+        tensor::getMixedSizes(builder, input.getLoc(), input);
+    if (index == 0) {
+      outputShape = inputShape;
+      concatOffsets.push_back(zero);
+    } else {
+      concatOffsets.push_back(outputShape[concatDim]);
+      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
+          builder, loc, addExpr,
+          {outputShape[concatDim], inputShape[concatDim]});
+    }
+    inputShapes.emplace_back(std::move(inputShape));
+  }
+
+  Value replacement = builder.create<tensor::EmptyOp>(
+      loc, outputShape, getType().getElementType());
+
+  int64_t rank = getType().getRank();
+  OpFoldResult one = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> strides(rank, one);
+  SmallVector<OpFoldResult> offsets(rank, zero);
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    offsets[concatDim] = concatOffsets[index];
+    auto insertSlice = builder.create<tensor::InsertSliceOp>(
+        loc, input, replacement, offsets, inputShapes[index], strides);
+    replacement = insertSlice.getResult();
+  }
+  if (replacement.getType() != getType()) {
+    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
+  }
+  return SmallVector<Value>{replacement};
+}
+
 LogicalResult
 ConcatOp::reifyResultShapes(OpBuilder &builder,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
index 7c8403c9609d8..a2a860fcb38ab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -33,54 +33,13 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
 
   LogicalResult matchAndRewrite(ConcatOp concatOp,
                                 PatternRewriter &rewriter) const override {
-    Location loc = concatOp.getLoc();
-    FailureOr<Value> dest =
-        tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
-    if (failed(dest))
-      return failure();
-
-    auto empty = dest->getDefiningOp<tensor::EmptyOp>();
-    if (!empty)
-      return failure();
-
-    int64_t dim = concatOp.getDim();
-    Value dimValue =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
-
-    int64_t rank = concatOp.getResultType().getRank();
-    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-
-    // Compute the partial sums for the slice offsets.
-    AffineExpr sum = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> partialSums = {sum};
-    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
-    for (auto [idx, input] :
-         llvm::enumerate(concatOp.getInputs().drop_back())) {
-      sum = sum + rewriter.getAffineDimExpr(idx + 1);
-      partialSums.push_back(sum);
-      offsetStrides.push_back(
-          rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
+    FailureOr<SmallVector<Value>> decomposed =
+        concatOp.decomposeOperation(rewriter);
+    if (failed(decomposed)) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "failed to get the decomposed insert slices");
     }
-    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
-                                        partialSums, rewriter.getContext());
-    SmallVector<OpFoldResult> dimOffsets =
-        affine::makeComposedFoldedMultiResultAffineApply(
-            rewriter, loc, partialSumMap, offsetStrides);
-
-    // Construct the chain of insert_slice ops into the destination.
-    Value result = *dest;
-    for (auto [input, offset] :
-         llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
-      SmallVector<OpFoldResult> sizes =
-          tensor::getMixedSizes(rewriter, loc, input);
-      offsets[dim] = offset;
-      result = rewriter.createOrFold<tensor::InsertSliceOp>(
-          loc, input, result, offsets, sizes, strides);
-    }
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        concatOp, concatOp.getResultType(), result);
+    rewriter.replaceOp(concatOp, decomposed.value()[0]);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
index c0f23b8eddbd5..2b1cb138ecda5 100644
--- a/mlir/test/Dialect/Tensor/decompose-concat.mlir
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -1,24 +1,23 @@
-// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse --mlir-print-local-scope %s | FileCheck %s
 
 func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.concat dim(1) %arg0, %arg1
              : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
-// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 4)>
 // CHECK-LABEL: func @decompose_dynamic_concat(
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<8x4xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
-//   CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
 //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-//       CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//       CHECK:   %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
-//       CHECK:   %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
-//       CHECK:   %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//       CHECK:   %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
-//       CHECK:   return %[[CONCAT]] : tensor<?x?xf32>
+//       CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//       CHECK:   %[[DIM0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//       CHECK:   %[[CONCAT_SIZE:.+]] = affine.apply affine_map<()[s0] -> (s0 + 4)>()[%[[DIM0]]]
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[CONCAT_SIZE]]) : tensor<8x?xf32>
+//       CHECK:   %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<8x?xf32>
+//       CHECK:   %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[DIM]], %[[DIM0]]] [1, 1] : tensor<?x?xf32> into tensor<8x?xf32>
+//       CHECK:   %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<8x?xf32> to tensor<?x?xf32>
+//       CHECK:   return %[[CAST]] : tensor<?x?xf32>
 
 func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
                                %arg1 : tensor<2xf32>,
@@ -42,12 +41,14 @@ func.func @decompose_static_concat_dim(%arg0 : tensor<1x?x64xf32>,
            : (tensor<1x?x64xf32>, tensor<1x?x64xf32>) -> tensor<1x?x128xf32>
   return %0 : tensor<1x?x128xf32>
 }
-// CHECK-LABEL: func @decompose_static_concat_dim
+// CHECK-LABEL: func @decompose_static_concat_dim(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x64xf32>)
 //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//       CHECK:   %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
+//       CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x64xf32>
+//       CHECK:   %[[DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x64xf32>
 //       CHECK:   tensor.empty(%[[DIM]]) : tensor<1x?x128xf32>
 //       CHECK:   tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[DIM]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
-//       CHECK:   %[[DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x64xf32>
 //       CHECK:   %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, 64] [1, %[[DIM1]], 64] [1, 1, 1] : tensor<1x?x64xf32> into tensor<1x?x128xf32>
 //       CHECK:   return %[[CONCAT]] : tensor<1x?x128xf32>
 
@@ -58,19 +59,23 @@ func.func @decompose_dynamic_into_static_concat_dim(%arg0 : tensor<1x?x?xf32>,
           : (tensor<1x?x?xf32>, tensor<1x?x?xf32>) -> tensor<1x?x128xf32>
   return %0 : tensor<1x?x128xf32>
 }
-// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim
+// CHECK-LABEL: func @decompose_dynamic_into_static_concat_dim(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>)
 //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//       CHECK:   %[[T0_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
-//       CHECK:   tensor.empty(%[[T0_DIM1]]) : tensor<1x?x128xf32>
-//       CHECK:   %[[T0_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:   %[[T0_DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x?xf32>
+//       CHECK:   %[[T0_DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:   %[[T1_DIM1:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<1x?x?xf32>
+//       CHECK:   %[[T1_DIM2:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<1x?x?xf32>
+//       CHECK:   %[[CONCAT_DIM:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%[[T0_DIM2]], %[[T1_DIM2]]]
+//       CHECK:   tensor.empty(%[[T0_DIM1]], %[[CONCAT_DIM]]) : tensor<1x?x?xf32>
 //       CHECK:   tensor.insert_slice %{{.*}}[0, 0, 0] [1, %[[T0_DIM1]], %[[T0_DIM2]]] [1, 1, 1]
-//  CHECK-SAME:     tensor<1x?x?xf32> into tensor<1x?x128xf32>
-//       CHECK:   %[[T1_DIM1:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?x?xf32>
-//       CHECK:   %[[T1_DIM2:.+]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x?x?xf32>
+//  CHECK-SAME:     tensor<1x?x?xf32> into tensor<1x?x?xf32>
 //       CHECK:   %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[0, 0, %[[T0_DIM2]]] [1, %[[T1_DIM1]], %[[T1_DIM2]]] [1, 1, 1]
-//  CHECK-SAME:     tensor<1x?x?xf32> into tensor<1x?x128xf32>
-//       CHECK:   return %[[CONCAT]] : tensor<1x?x128xf32>
+//  CHECK-SAME:     tensor<1x?x?xf32> into tensor<1x?x?xf32>
+//       CHECK:   %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<1x?x?xf32> to tensor<1x?x128xf32>
+//       CHECK:   return %[[CAST]] : tensor<1x?x128xf32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
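For illustration, a minimal sketch of the rewrite that ConcatOp::decomposeOperation performs on a fully static concat; the function name and shapes below are invented for this example (the dynamic cases are exercised by the tests above):

  func.func @concat_static_example(%a : tensor<2x3xf32>, %b : tensor<2x4xf32>) -> tensor<2x7xf32> {
    // Concatenate along dim 1: the result size along that dim is 3 + 4 = 7.
    %0 = tensor.concat dim(1) %a, %b
             : (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
    return %0 : tensor<2x7xf32>
  }

becomes an empty destination plus a chain of insert_slices. No trailing tensor.cast is created here because the computed type tensor<2x7xf32> already equals the result type:

  // Destination tensor with the summed concat dimension (3 + 4 = 7).
  %empty = tensor.empty() : tensor<2x7xf32>
  // First input at offset 0 along the concat dimension.
  %s0 = tensor.insert_slice %a into %empty[0, 0] [2, 3] [1, 1]
            : tensor<2x3xf32> into tensor<2x7xf32>
  // Second input at offset 3, the accumulated size of the first input.
  %s1 = tensor.insert_slice %b into %s0[0, 3] [2, 4] [1, 1]
            : tensor<2x4xf32> into tensor<2x7xf32>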