@@ -330,8 +330,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
 
 /// Determines whether the tensor::CastOp casts to a more static version of the
 /// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
 /// 1. source and result are ranked tensors with same element type and rank.
 /// 2. the result type has more static information than the source.
 ///
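
As a concrete illustration of the conditions this comment lists (an editor's sketch, not part of the patch; the values are hypothetical), a cast qualifies when it only tightens dynamic dimensions:

```mlir
// Source and result are ranked tensors with the same element type and rank,
// and the result type (tensor<4x?xf32>) has strictly more static information
// than the source (tensor<?x?xf32>), so a producer of %0 could absorb it.
%1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
```
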
@@ -773,11 +774,116 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
     return success();
   }
 };
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x?xi32>)
+///          -> tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+///   %2 = tensor.concat dim(0) %0, %cast :
+///          (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto operandTensorTypes =
+        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+          return llvm::cast<RankedTensorType>(type);
+        });
+
+    int64_t dim = concatOp.getDim();
+    ArrayRef<int64_t> inferredResultShape =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+    // Find operands for which a more static shape can be inferred.
+    SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+      // Compute inferred type for operand.
+      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Check if inferred type is more static.
+      if (!preservesStaticInformation(inferredOperandType, operandType)) {
+        refinedTypes.push_back({operandIdx, inferredOperandType});
+      }
+    }
+
+    if (refinedTypes.empty()) {
+      return failure();
+    }
+
+    // Use refined types for operands, insert casts for original type.
+    SmallVector<Value> newOperands = concatOp.getOperands();
+    for (auto [operandIdx, refinedType] : refinedTypes) {
+      newOperands[operandIdx] = rewriter.create<CastOp>(
+          concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
+    }
+    rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
+                                          dim, newOperands);
+
+    return success();
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
+///          -> tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %3 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
+///          -> tensor<?x12xi32>
+///   %2 = tensor.cast %3 : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as inferred result type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return success();
+  }
+};
 } // namespace
 
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SingleInputConcatOp>(context);
+  results
+      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
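
Taken together, the two new patterns compose: `InferConcatOperandTypes` refines operand types, `InferConcatResultType` tightens the result type, and `tensor.cast` ops are left at the boundaries to preserve the original types. Below is a sketch of the expected end-to-end effect, assuming a `--canonicalize` run; the function and value names (`@concat_refine`, `%a`, `%b`) are hypothetical, not from the patch:

```mlir
// Before: %b and the result are less static than their peers allow.
func.func @concat_refine(%a: tensor<?x12xi32>, %b: tensor<?x?xi32>)
    -> tensor<?x?xi32> {
  %c = tensor.concat dim(0) %a, %b
      : (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
  return %c : tensor<?x?xi32>
}

// After: both patterns have fired; casts keep the function signature intact.
func.func @concat_refine(%a: tensor<?x12xi32>, %b: tensor<?x?xi32>)
    -> tensor<?x?xi32> {
  %cast = tensor.cast %b : tensor<?x?xi32> to tensor<?x12xi32>
  %c = tensor.concat dim(0) %a, %cast
      : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
  %res = tensor.cast %c : tensor<?x12xi32> to tensor<?x?xi32>
  return %res : tensor<?x?xi32>
}
```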