@@ -330,8 +330,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
330330
331331// / Determines whether the tensor::CastOp casts to a more static version of the
332332// / source tensor. This is useful to fold into a producing op and implement
333- // / canonicaliation patterns with the `tensor.cast` op as the root, but producer
334- // / being from different dialects. Returns true when all conditions are met:
333+ // / canonicalization patterns with the `tensor.cast` op as the root, but
334+ // / producer being from different dialects. Returns true when all conditions are
335+ // / met:
335336/// 1. source and result are ranked tensors with same element type and rank.
336337// / 2. the result type has more static information than the source.
337338// /
@@ -773,11 +774,118 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
773774    return  success ();
774775  }
775776};
777+ 
778+ // / Propagate static shapes into the operands of a `tensor.concat`.
779+ // /
780+ // / `tensor.concat` requires every operand to match on all dimensions except the
781+ // / concatenation dimension. If one operand is already static in those
782+ // / dimensions, the other operands may safely be refined to that same static
783+ // / shape.
784+ // /
785+ // / Example:
786+ // /
787+ // / ```mlir
788+ // /   // Second operand dim 1 has dynamic shape constrained by dim 1 of first
789+ // /   // operand.
790+ // /   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
791+ // /        tensor<?x12xi32>
792+ // / ```
793+ // / ->
794+ // / ```mlir
795+ // /   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
796+ // /   %2 = tensor.concat dim(0) %0, %cast :
797+ // /        (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
798+ // / ```
799+ struct  InferConcatOperandTypes  : public  OpRewritePattern <ConcatOp> {
800+   using  OpRewritePattern<ConcatOp>::OpRewritePattern;
801+ 
802+   LogicalResult matchAndRewrite (ConcatOp concatOp,
803+                                 PatternRewriter &rewriter) const  override  {
804+     auto  operandTensorTypes =
805+         llvm::map_range (concatOp->getOperandTypes (), [](Type type) {
806+           return  llvm::cast<RankedTensorType>(type);
807+         });
808+ 
809+     int64_t  dim = concatOp.getDim ();
810+     ArrayRef<int64_t > inferredResultShape =
811+         concatOp.inferResultType (dim, concatOp->getOperandTypes ()).getShape ();
812+ 
813+     //  Find operands for which a more static shape can be inferred.
814+     SmallVector<std::tuple<size_t , RankedTensorType>> refinedTypes;
815+     for  (auto  [operandIdx, operandType] : llvm::enumerate (operandTensorTypes)) {
816+       //  Compute inferred type for operand.
817+       SmallVector<int64_t > inferredOperandShape (inferredResultShape);
818+       inferredOperandShape[dim] = operandType.getDimSize (dim);
819+       auto  inferredOperandType = RankedTensorType::get (
820+           inferredOperandShape, operandType.getElementType ());
821+ 
822+       //  Check if inferred type is more static.
823+       if  (!preservesStaticInformation (inferredOperandType, operandType)) {
824+         refinedTypes.push_back ({operandIdx, inferredOperandType});
825+       }
826+     }
827+ 
828+     if  (refinedTypes.empty ()) {
829+       return  failure ();
830+     }
831+ 
832+     //  Use refined types for operands, insert casts for original type.
833+     SmallVector<Value> newOperands = concatOp.getOperands ();
834+     for  (auto  [operandIdx, refinedType] : refinedTypes) {
835+       newOperands[operandIdx] = rewriter.create <CastOp>(
836+           concatOp->getLoc (), refinedType, concatOp.getOperand (operandIdx));
837+     }
838+     rewriter.replaceOpWithNewOp <ConcatOp>(concatOp, concatOp.getResultType (),
839+                                           dim, newOperands);
840+ 
841+     return  success ();
842+   }
843+ };
844+ 
845+ //  Ensure `tensor.concat`'s result type is at least as static as can be inferred
846+ //  from its operand types.
847+ // /
848+ // / Example:
849+ // / ```mlir
850+ // /   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
851+ // /   tensor<?x?xi32>
852+ // / ```
853+ // / ->
854+ // / ```mlir
855+ // /   %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>)
856+ // /   -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to
857+ // /   tensor<?x?xi32>
858+ // / ```
859+ struct  InferConcatResultType  : public  OpRewritePattern <ConcatOp> {
860+   using  OpRewritePattern<ConcatOp>::OpRewritePattern;
861+ 
862+   LogicalResult matchAndRewrite (ConcatOp concatOp,
863+                                 PatternRewriter &rewriter) const  override  {
864+     int64_t  dim = concatOp.getDim ();
865+     RankedTensorType inferredResultType =
866+         concatOp.inferResultType (dim, concatOp->getOperandTypes ());
867+ 
868+     //  The result type should be at least as static as inferred result type.
869+     if  (preservesStaticInformation (inferredResultType,
870+                                    concatOp.getResultType ())) {
871+       return  failure ();
872+     }
873+ 
874+     auto  newConcatOp = rewriter.create <ConcatOp>(
875+         concatOp->getLoc (), inferredResultType, dim, concatOp->getOperands ());
876+     rewriter.replaceOpWithNewOp <CastOp>(concatOp, concatOp.getResultType (),
877+                                         newConcatOp);
878+ 
879+     return  llvm::success ();
880+   }
881+ };
776882} //  namespace
777883
778884void  ConcatOp::getCanonicalizationPatterns (RewritePatternSet &results,
779885                                           MLIRContext *context) {
780-   results.add <SingleInputConcatOp>(context);
886+   results
887+       .add <SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
888+           context);
781889}
782890
783891// ===----------------------------------------------------------------------===//
0 commit comments