3333#include " llvm/ADT/STLExtras.h"
3434#include " llvm/ADT/SmallBitVector.h"
3535#include " llvm/ADT/StringRef.h"
36+ #include " llvm/Support/LogicalResult.h"
3637#include " llvm/Support/MathExtras.h"
3738#include < algorithm>
3839#include < optional>
@@ -809,7 +810,7 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
809810 ConcatOp::inferResultType (dim, concatOp->getOperandTypes ()).getShape ();
810811
811812 // Find operands for which a more static shape can be inferred.
812- SmallVector<std::tuple< size_t , RankedTensorType>> refinedTypes ;
813+ LogicalResult matched = failure () ;
813814 for (auto [operandIdx, operandType] : llvm::enumerate (operandTensorTypes)) {
814815 // Compute inferred type for operand.
815816 SmallVector<int64_t > inferredOperandShape (inferredResultShape);
@@ -819,24 +820,20 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
819820
820821 // Check if inferred type is more static.
821822 if (!preservesStaticInformation (inferredOperandType, operandType)) {
822- refinedTypes.push_back ({operandIdx, inferredOperandType});
823+ matched = success ();
824+
825+ // Use refined operand type and create cast from original operand.
826+ auto castOp =
827+ rewriter.create <CastOp>(concatOp->getLoc (), inferredOperandType,
828+ concatOp.getOperand (operandIdx));
829+ rewriter.modifyOpInPlace (
830+ concatOp, [=, operandIdx = (size_t )operandIdx] {
831+ concatOp->setOperand (operandIdx, castOp->getResult (0 ));
832+ });
823833 }
824834 }
825835
826- if (refinedTypes.empty ()) {
827- return failure ();
828- }
829-
830- // Use refined types for operands, insert casts for original type.
831- SmallVector<Value> newOperands = concatOp.getOperands ();
832- for (auto [operandIdx, refinedType] : refinedTypes) {
833- newOperands[operandIdx] = rewriter.create <CastOp>(
834- concatOp->getLoc (), refinedType, concatOp.getOperand (operandIdx));
835- }
836- rewriter.replaceOpWithNewOp <ConcatOp>(concatOp, concatOp.getResultType (),
837- dim, newOperands);
838-
839- return success ();
836+ return matched;
840837 }
841838};
842839
0 commit comments