diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 815806f06b472..6c32476d8656f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -33,6 +33,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include <algorithm>
 #include <optional>
@@ -330,8 +331,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
 
 /// Determines whether the tensor::CastOp casts to a more static version of the
 /// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
 /// 1. source and result and ranked tensors with same element type and rank.
 /// 2. the result type has more static information than the source.
 ///
@@ -773,11 +775,111 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
     return success();
   }
 };
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///   tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+///   %2 = tensor.concat dim(0) %0, %cast :
+///   (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto operandTensorTypes =
+        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+          return llvm::cast<RankedTensorType>(type);
+        });
+
+    int64_t dim = concatOp.getDim();
+    ArrayRef<int64_t> inferredResultShape =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+    // Find operands for which a more static shape can be inferred.
+    LogicalResult matched = failure();
+    for (auto [operandIdx, operandType] :
+         llvm::enumerate(operandTensorTypes)) {
+      // Compute inferred type for operand.
+      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Check if inferred type is more static.
+      if (!preservesStaticInformation(inferredOperandType, operandType)) {
+        matched = success();
+
+        // Use refined operand type and create cast from original operand.
+        auto castOp =
+            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
+                                    concatOp.getOperand(operandIdx));
+        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
+          concatOp->setOperand(operandIdx, castOp->getResult(0));
+        });
+      }
+    }
+
+    return matched;
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///   tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
+///   -> tensor<?x12xi32>
+///   %cast = tensor.cast %2 : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as the inferred result
+    // type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return success();
+  }
+};
 } // namespace
 
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SingleInputConcatOp>(context);
+  results
+      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..cdcd7f305d2d9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -136,6 +136,32 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
 
 // -----
 
+// CHECK-LABEL: infer_concat_operand_types
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xi32>
+func.func @infer_concat_operand_types(%arg0: tensor<?x12xi32>, %arg1: tensor<?x?xi32>) -> (tensor<?x12xi32>) {
+  // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xi32> to tensor<?x12xi32>
+  %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
+  // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[CAST]] : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+  return %0 : tensor<?x12xi32>
+  // CHECK-NEXT: return %[[CONCAT]] : tensor<?x12xi32>
+}
+
+// -----
+
+// CHECK-LABEL: infer_concat_return_type
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x12xi32>
+func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12xi32>) -> (tensor<?x?xi32>) {
+  %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
+  // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[ARG1]] : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+  // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+  // CHECK-NEXT: return %[[CAST]] : tensor<?x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract
 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
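To see how the two new patterns compose, here is a minimal sketch (a hypothetical input, not taken from the patch's tests; the function name `@concat_compose` is illustrative) that can be fed through `mlir-opt --canonicalize`:

```mlir
// Neither the second operand nor the result is as static as it could be.
func.func @concat_compose(%arg0: tensor<5x12xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<5x12xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
  return %0 : tensor<?x?xi32>
}
```

`InferConcatOperandTypes` should refine `%arg1` to `tensor<?x12xi32>` through a `tensor.cast` (all operands must match the inferred shape on the non-concatenated dimensions), and `InferConcatResultType` should then refine the concat's result type to `tensor<?x12xi32>`, casting back to the original `tensor<?x?xi32>` for existing uses.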