diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp index 8b8ab6be99b0d..cb42cfe8159b0 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp @@ -93,34 +93,19 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) { Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenRecipe *R) { unsigned Opcode = R->getOpcode(); - switch (Opcode) { - case Instruction::ICmp: - case Instruction::FCmp: - return IntegerType::get(Ctx, 1); - case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::SRem: - case Instruction::URem: - case Instruction::Add: - case Instruction::FAdd: - case Instruction::Sub: - case Instruction::FSub: - case Instruction::Mul: - case Instruction::FMul: - case Instruction::FDiv: - case Instruction::FRem: - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: { + if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) || + Instruction::isBitwiseLogicOp(Opcode)) { Type *ResTy = inferScalarType(R->getOperand(0)); assert(ResTy == inferScalarType(R->getOperand(1)) && "types for both operands must match for binary op"); CachedTypes[R->getOperand(1)] = ResTy; return ResTy; } + + switch (Opcode) { + case Instruction::ICmp: + case Instruction::FCmp: + return IntegerType::get(Ctx, 1); case Instruction::FNeg: case Instruction::Freeze: return inferScalarType(R->getOperand(0)); @@ -157,36 +142,26 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPWidenSelectRecipe *R) { } Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) { - switch (R->getUnderlyingInstr()->getOpcode()) { - case Instruction::Call: { - unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1); - return cast(R->getOperand(CallIdx)->getLiveInIRValue()) - ->getReturnType(); - } - case Instruction::UDiv: - case Instruction::SDiv: - case Instruction::SRem: - case Instruction::URem: - case Instruction::Add: - case Instruction::FAdd: - case Instruction::Sub: - case Instruction::FSub: - case Instruction::Mul: - case Instruction::FMul: - case Instruction::FDiv: - case Instruction::FRem: - case Instruction::Shl: - case Instruction::LShr: - case Instruction::AShr: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: { + unsigned Opcode = R->getUnderlyingInstr()->getOpcode(); + + if (Instruction::isBinaryOp(Opcode) || Instruction::isShift(Opcode) || + Instruction::isBitwiseLogicOp(Opcode)) { Type *ResTy = inferScalarType(R->getOperand(0)); assert(ResTy == inferScalarType(R->getOperand(1)) && "inferred types for operands of binary op don't match"); CachedTypes[R->getOperand(1)] = ResTy; return ResTy; } + + if (Instruction::isCast(Opcode)) + return R->getUnderlyingInstr()->getType(); + + switch (Opcode) { + case Instruction::Call: { + unsigned CallIdx = R->getNumOperands() - (R->isPredicated() ? 2 : 1); + return cast(R->getOperand(CallIdx)->getLiveInIRValue()) + ->getReturnType(); + } case Instruction::Select: { Type *ResTy = inferScalarType(R->getOperand(1)); assert(ResTy == inferScalarType(R->getOperand(2)) && @@ -197,21 +172,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPReplicateRecipe *R) { case Instruction::ICmp: case Instruction::FCmp: return IntegerType::get(Ctx, 1); - case Instruction::AddrSpaceCast: case Instruction::Alloca: - case Instruction::BitCast: - case Instruction::Trunc: - case Instruction::SExt: - case Instruction::ZExt: - case Instruction::FPExt: - case Instruction::FPTrunc: case Instruction::ExtractValue: - case Instruction::SIToFP: - case Instruction::UIToFP: - case Instruction::FPToSI: - case Instruction::FPToUI: - case Instruction::PtrToInt: - case Instruction::IntToPtr: return R->getUnderlyingInstr()->getType(); case Instruction::Freeze: case Instruction::FNeg: