@@ -832,6 +832,7 @@ struct InstructionsState {
832832 InstructionsState() = delete;
833833 InstructionsState(Value *OpValue, Instruction *MainOp, Instruction *AltOp)
834834 : OpValue(OpValue), MainOp(MainOp), AltOp(AltOp) {}
835+ static InstructionsState invalid() { return {nullptr, nullptr, nullptr}; }
835836};
836837
837838} // end anonymous namespace
@@ -891,20 +892,19 @@ static bool isCmpSameOrSwapped(const CmpInst *BaseCI, const CmpInst *CI,
891892/// could be vectorized even if its structure is diverse.
892893static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
893894 const TargetLibraryInfo &TLI) {
894- constexpr unsigned BaseIndex = 0;
895895 // Make sure these are all Instructions.
896- if (llvm::any_of (VL, [](Value *V) { return !isa <Instruction>(V); } ))
897- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
896+ if (!all_of (VL, IsaPred <Instruction>))
897+ return InstructionsState::invalid( );
898898
899- bool IsCastOp = isa<CastInst>(VL[BaseIndex]);
900- bool IsBinOp = isa<BinaryOperator>(VL[BaseIndex]);
901- bool IsCmpOp = isa<CmpInst>(VL[BaseIndex]);
899+ Value *V = VL.front();
900+ bool IsCastOp = isa<CastInst>(V);
901+ bool IsBinOp = isa<BinaryOperator>(V);
902+ bool IsCmpOp = isa<CmpInst>(V);
902903 CmpInst::Predicate BasePred =
903- IsCmpOp ? cast<CmpInst>(VL[BaseIndex])->getPredicate()
904- : CmpInst::BAD_ICMP_PREDICATE;
905- unsigned Opcode = cast<Instruction>(VL[BaseIndex])->getOpcode();
904+ IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
905+ unsigned Opcode = cast<Instruction>(V)->getOpcode();
906906 unsigned AltOpcode = Opcode;
907- unsigned AltIndex = BaseIndex ;
907+ unsigned AltIndex = 0 ;
908908
909909 bool SwappedPredsCompatible = [&]() {
910910 if (!IsCmpOp)
@@ -931,14 +931,14 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
931931 }();
932932 // Check for one alternate opcode from another BinaryOperator.
933933 // TODO - generalize to support all operators (types, calls etc.).
934- auto *IBase = cast<Instruction>(VL[BaseIndex] );
934+ auto *IBase = cast<Instruction>(V );
935935 Intrinsic::ID BaseID = 0;
936936 SmallVector<VFInfo> BaseMappings;
937937 if (auto *CallBase = dyn_cast<CallInst>(IBase)) {
938938 BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
939939 BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
940940 if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
941- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
941+ return InstructionsState::invalid( );
942942 }
943943 for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
944944 auto *I = cast<Instruction>(VL[Cnt]);
@@ -970,7 +970,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
970970 }
971971 }
972972 } else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) {
973- auto *BaseInst = cast<CmpInst>(VL[BaseIndex] );
973+ auto *BaseInst = cast<CmpInst>(V );
974974 Type *Ty0 = BaseInst->getOperand(0)->getType();
975975 Type *Ty1 = Inst->getOperand(0)->getType();
976976 if (Ty0 == Ty1) {
@@ -988,7 +988,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
988988 if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
989989 continue;
990990 auto *AltInst = cast<CmpInst>(VL[AltIndex]);
991- if (AltIndex != BaseIndex ) {
991+ if (AltIndex) {
992992 if (isCmpSameOrSwapped(AltInst, Inst, TLI))
993993 continue;
994994 } else if (BasePred != CurrentPred) {
@@ -1007,27 +1007,28 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
10071007 if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
10081008 if (Gep->getNumOperands() != 2 ||
10091009 Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType())
1010- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1010+ return InstructionsState::invalid( );
10111011 } else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
10121012 if (!isVectorLikeInstWithConstOps(EI))
1013- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1013+ return InstructionsState::invalid( );
10141014 } else if (auto *LI = dyn_cast<LoadInst>(I)) {
10151015 auto *BaseLI = cast<LoadInst>(IBase);
10161016 if (!LI->isSimple() || !BaseLI->isSimple())
1017- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1017+ return InstructionsState::invalid( );
10181018 } else if (auto *Call = dyn_cast<CallInst>(I)) {
10191019 auto *CallBase = cast<CallInst>(IBase);
10201020 if (Call->getCalledFunction() != CallBase->getCalledFunction())
1021- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1022- if (Call->hasOperandBundles() && (!CallBase->hasOperandBundles() ||
1023- !std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(),
1024- Call->op_begin() + Call->getBundleOperandsEndIndex(),
1025- CallBase->op_begin() +
1026- CallBase->getBundleOperandsStartIndex())))
1027- return InstructionsState(VL[BaseIndex], nullptr, nullptr);
1021+ return InstructionsState::invalid();
1022+ if (Call->hasOperandBundles() &&
1023+ (!CallBase->hasOperandBundles() ||
1024+ !std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(),
1025+ Call->op_begin() + Call->getBundleOperandsEndIndex(),
1026+ CallBase->op_begin() +
1027+ CallBase->getBundleOperandsStartIndex())))
1028+ return InstructionsState::invalid();
10281029 Intrinsic::ID ID = getVectorIntrinsicIDForCall(Call, &TLI);
10291030 if (ID != BaseID)
1030- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1031+ return InstructionsState::invalid( );
10311032 if (!ID) {
10321033 SmallVector<VFInfo> Mappings = VFDatabase(*Call).getMappings(*Call);
10331034 if (Mappings.size() != BaseMappings.size() ||
@@ -1037,15 +1038,15 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
10371038 Mappings.front().Shape.VF != BaseMappings.front().Shape.VF ||
10381039 Mappings.front().Shape.Parameters !=
10391040 BaseMappings.front().Shape.Parameters)
1040- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1041+ return InstructionsState::invalid( );
10411042 }
10421043 }
10431044 continue;
10441045 }
1045- return InstructionsState(VL[BaseIndex], nullptr, nullptr );
1046+ return InstructionsState::invalid( );
10461047 }
10471048
1048- return InstructionsState(VL[BaseIndex] , cast<Instruction>(VL[BaseIndex] ),
1049+ return InstructionsState(V , cast<Instruction>(V ),
10491050 cast<Instruction>(VL[AltIndex]));
10501051}
10511052
@@ -8019,7 +8020,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
80198020 }
80208021
80218022 // Don't handle vectors.
8022- if (!SLPReVec && getValueType(S.OpValue )->isVectorTy()) {
8023+ if (!SLPReVec && getValueType(VL.front() )->isVectorTy()) {
80238024 LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n");
80248025 newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
80258026 return;
@@ -8088,7 +8089,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
80888089 UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize;
80898090 bool AreAllSameBlock = S.getOpcode() && allSameBlock(VL);
80908091 bool AreScatterAllGEPSameBlock =
8091- (IsScatterVectorizeUserTE && S.OpValue ->getType()->isPointerTy() &&
8092+ (IsScatterVectorizeUserTE && VL.front() ->getType()->isPointerTy() &&
80928093 VL.size() > 2 &&
80938094 all_of(VL,
80948095 [&BB](Value *V) {
@@ -8104,7 +8105,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
81048105 SortedIndices));
81058106 bool AreAllSameInsts = AreAllSameBlock || AreScatterAllGEPSameBlock;
81068107 if (!AreAllSameInsts || (!S.getOpcode() && allConstant(VL)) || isSplat(VL) ||
8107- (isa <InsertElementInst, ExtractValueInst, ExtractElementInst>(
8108+ (isa_and_present <InsertElementInst, ExtractValueInst, ExtractElementInst>(
81088109 S.OpValue) &&
81098110 !all_of(VL, isVectorLikeInstWithConstOps)) ||
81108111 NotProfitableForVectorization(VL)) {
@@ -8161,7 +8162,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
81618162 // Special processing for sorted pointers for ScatterVectorize node with
81628163 // constant indeces only.
81638164 if (!AreAllSameBlock && AreScatterAllGEPSameBlock) {
8164- assert(S.OpValue ->getType()->isPointerTy() &&
8165+ assert(VL.front() ->getType()->isPointerTy() &&
81658166 count_if(VL, IsaPred<GetElementPtrInst>) >= 2 &&
81668167 "Expected pointers only.");
81678168 // Reset S to make it GetElementPtr kind of node.
0 commit comments