@@ -800,23 +800,22 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
800800
801801  LogicalResult matchAndRewrite (ConcatOp concatOp,
802802                                PatternRewriter &rewriter) const  override  {
803-     auto  operandTensorTypes =
804-         llvm::map_range (concatOp->getOperandTypes (), [](Type type) {
805-           return  llvm::cast<RankedTensorType>(type);
806-         });
807- 
808803    int64_t  dim = concatOp.getDim ();
809-     ArrayRef< int64_t > inferredResultShape  =
810-         ConcatOp::inferResultType (dim, concatOp->getOperandTypes ()). getShape () ;
804+     RankedTensorType inferredResultType  =
805+         ConcatOp::inferResultType (dim, concatOp->getOperandTypes ());
811806
812807    //  Find operands for which a more static shape can be inferred.
813808    LogicalResult matched = failure ();
814-     for  (auto  [operandIdx, operandType] : llvm::enumerate (operandTensorTypes)) {
809+     //  Inferred operand shapes are identical in every dimension except the
810+     //  concatenation dimension.
811+     SmallVector<int64_t > inferredOperandShape (inferredResultType.getShape ());
812+     for  (auto  [operandIdx, operandType] :
813+          llvm::enumerate (concatOp->getOperandTypes ())) {
815814      //  Compute inferred type for operand.
816-       SmallVector< int64_t >  inferredOperandShape (inferredResultShape); 
817-       inferredOperandShape[dim] =  operandType.getDimSize (dim);
815+       inferredOperandShape[dim] = 
816+           cast<RankedTensorType>( operandType) .getDimSize (dim);
818817      auto  inferredOperandType = RankedTensorType::get (
819-           inferredOperandShape, operandType .getElementType ());
818+           inferredOperandShape, inferredResultType .getElementType ());
820819
821820      //  Check if inferred type is more static.
822821      if  (!preservesStaticInformation (inferredOperandType, operandType)) {
0 commit comments