@@ -69,10 +69,11 @@ class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
 
 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                           const TargetTransformInfo &TTI, const DataLayout &DL,
-                          DomTreeUpdater *DTU);
+                          bool HasBranchDivergence, DomTreeUpdater *DTU);
 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                              const TargetTransformInfo &TTI,
-                             const DataLayout &DL, DomTreeUpdater *DTU);
+                             const DataLayout &DL, bool HasBranchDivergence,
+                             DomTreeUpdater *DTU);
 
 char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
 
@@ -141,8 +142,9 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
 // %10 = extractelement <16 x i1> %mask, i32 2
 // br i1 %10, label %cond.load4, label %else5
 //
-static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
-                                DomTreeUpdater *DTU, bool &ModifiedDT) {
+static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
+                                CallInst *CI, DomTreeUpdater *DTU,
+                                bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
   Value *Mask = CI->getArgOperand(2);
@@ -221,25 +223,26 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
     return;
   }
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  // Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
-  // - what's a good way to detect this?
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs and other
+  // machines with divergence, as there each i1 needs a vector register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
 
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
-    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
-    // %mask_1 = and i16 %scalar_mask, i32 1 << Idx
-    // %cond = icmp ne i16 %mask_1, 0
-    // br i1 %mask_1, label %cond.load, label %else
+    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
+    // %else ] %mask_1 = and i16 %scalar_mask, i32 1 << Idx %cond = icmp ne i16
+    // %mask_1, 0 br i1 %mask_1, label %cond.load, label %else
     //
+    // On GPUs, use
+    // %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
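
For readers skimming the diff, here is a minimal standalone sketch of the two predicate forms the loop above now chooses between. It is not part of the patch: the helper name buildLanePredicate is invented here, and the endianness remapping that the real pass applies via adjustForEndian() is deliberately omitted, so the bit-test branch assumes a little-endian layout.

// Illustrative sketch only; see the caveats above.
#include "llvm/ADT/APInt.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

static Value *buildLanePredicate(IRBuilder<> &Builder, Value *Mask,
                                 Value *SclrMask, unsigned VectorWidth,
                                 unsigned Idx) {
  if (SclrMask) {
    // Non-divergent targets: SclrMask is the <N x i1> mask bitcast to iN, so
    // lane Idx becomes a cheap scalar bit test:
    //   %p = icmp ne iN (and iN %scalar_mask, 1 << Idx), 0
    Value *Bit = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
    return Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Bit),
                                Builder.getIntN(VectorWidth, 0));
  }
  // Divergent targets (GPUs): SclrMask was never materialized, so keep the
  // mask in vector form and read a single lane, matching the comment in the
  // patch that the bitcast-to-scalar trick produces worse code there:
  //   %p = extractelement <N x i1> %mask, i64 Idx
  return Builder.CreateExtractElement(Mask, Idx, "mask_bit");
}
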
@@ -312,8 +315,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
 // store i32 %6, i32* %7
 // br label %else2
 // . . .
-static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
-                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
+static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
+                                 CallInst *CI, DomTreeUpdater *DTU,
+                                 bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptr = CI->getArgOperand(1);
   Value *Alignment = CI->getArgOperand(2);
@@ -378,10 +382,10 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -393,8 +397,11 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
     // %cond = icmp ne i16 %mask_1, 0
     // br i1 %mask_1, label %cond.store, label %else
     //
+    // On GPUs, use
+    // %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -461,7 +468,8 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
 // . . .
 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
 // ret <16 x i32> %Result
-static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedGather(const DataLayout &DL,
+                                  bool HasBranchDivergence, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptrs = CI->getArgOperand(0);
   Value *Alignment = CI->getArgOperand(1);
@@ -500,9 +508,10 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -514,9 +523,12 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
     // %cond = icmp ne i16 %mask_1, 0
     // br i1 %Mask1, label %cond.load, label %else
     //
+    // On GPUs, use
+    // %cond = extractelement %mask, Idx
+    // instead
 
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -591,7 +603,8 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
 // store i32 %Elt1, i32* %Ptr1, align 4
 // br label %else2
 // . . .
-static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedScatter(const DataLayout &DL,
+                                   bool HasBranchDivergence, CallInst *CI,
                                    DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
   Value *Ptrs = CI->getArgOperand(1);
@@ -629,8 +642,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
   // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -642,8 +655,11 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
     // %cond = icmp ne i16 %mask_1, 0
     // br i1 %Mask1, label %cond.store, label %else
     //
+    // On GPUs, use
+    // %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -681,7 +697,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
   ModifiedDT = true;
 }
 
-static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedExpandLoad(const DataLayout &DL,
+                                      bool HasBranchDivergence, CallInst *CI,
                                       DomTreeUpdater *DTU, bool &ModifiedDT) {
   Value *Ptr = CI->getArgOperand(0);
   Value *Mask = CI->getArgOperand(1);
@@ -738,23 +755,27 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
 
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
-    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
-    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
-    // br i1 %mask_1, label %cond.load, label %else
+    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
+    // %else ] %mask_1 = extractelement <16 x i1> %mask, i32 Idx br i1 %mask_1,
+    // label %cond.load, label %else
     //
+    // On GPUs, use
+    // %cond = extractelement %mask, Idx
+    // instead
 
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -813,7 +834,8 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
   ModifiedDT = true;
 }
 
-static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
+static void scalarizeMaskedCompressStore(const DataLayout &DL,
+                                         bool HasBranchDivergence, CallInst *CI,
                                          DomTreeUpdater *DTU,
                                          bool &ModifiedDT) {
   Value *Src = CI->getArgOperand(0);
@@ -855,9 +877,10 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
   }
 
   // If the mask is not v1i1, use scalar bit test operations. This generates
-  // better results on X86 at least.
-  Value *SclrMask;
-  if (VectorWidth != 1) {
+  // better results on X86 at least. However, don't do this on GPUs or other
+  // machines with branch divergence, as there, each i1 takes up a register.
+  Value *SclrMask = nullptr;
+  if (VectorWidth != 1 && !HasBranchDivergence) {
     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
   }
@@ -868,8 +891,11 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
     // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     // br i1 %mask_1, label %cond.store, label %else
     //
+    // On GPUs, use
+    // %cond = extractelement %mask, Idx
+    // instead
     Value *Predicate;
-    if (VectorWidth != 1) {
+    if (SclrMask != nullptr) {
       Value *Mask = Builder.getInt(APInt::getOneBitSet(
           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -993,12 +1019,13 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI,
   bool EverMadeChange = false;
   bool MadeChange = true;
   auto &DL = F.getDataLayout();
+  bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
   while (MadeChange) {
     MadeChange = false;
     for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
       bool ModifiedDTOnIteration = false;
       MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
-                                  DTU ? &*DTU : nullptr);
+                                  HasBranchDivergence, DTU ? &*DTU : nullptr);
 
       // Restart BB iteration if the dominator tree of the Function was changed
       if (ModifiedDTOnIteration)
@@ -1032,13 +1059,14 @@ ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
 
 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                           const TargetTransformInfo &TTI, const DataLayout &DL,
-                          DomTreeUpdater *DTU) {
+                          bool HasBranchDivergence, DomTreeUpdater *DTU) {
   bool MadeChange = false;
 
   BasicBlock::iterator CurInstIterator = BB.begin();
   while (CurInstIterator != BB.end()) {
     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
-      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
+      MadeChange |=
+          optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
     if (ModifiedDT)
       return true;
   }
@@ -1048,7 +1076,8 @@ static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
 
 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                              const TargetTransformInfo &TTI,
-                             const DataLayout &DL, DomTreeUpdater *DTU) {
+                             const DataLayout &DL, bool HasBranchDivergence,
+                             DomTreeUpdater *DTU) {
   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
   if (II) {
     // The scalarization code below does not work for scalable vectors.
@@ -1071,14 +1100,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
               CI->getType(),
               cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
         return false;
-      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_store:
       if (TTI.isLegalMaskedStore(
               CI->getArgOperand(0)->getType(),
               cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
         return false;
-      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_gather: {
       MaybeAlign MA =
@@ -1089,7 +1118,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
       if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
           !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
         return false;
-      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     }
     case Intrinsic::masked_scatter: {
@@ -1102,22 +1131,23 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
           !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                            Alignment))
         return false;
-      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     }
     case Intrinsic::masked_expandload:
       if (TTI.isLegalMaskedExpandLoad(
              CI->getType(),
              CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
        return false;
-      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
       return true;
     case Intrinsic::masked_compressstore:
       if (TTI.isLegalMaskedCompressStore(
              CI->getArgOperand(0)->getType(),
              CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
        return false;
-      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
+      scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
+                                   ModifiedDT);
       return true;
     }
   }
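
As a usage note, and not something the patch itself adds: which form the pass emits is decided entirely by the TargetTransformInfo it is given, through TTI.hasBranchDivergence(&F). A minimal new-pass-manager driver might look like the sketch below; the wrapper function name is invented, but the LLVM APIs are real. With no TargetMachine the generic TTI reports no branch divergence and the scalar bit-test form is kept, while constructing the PassBuilder with a GPU TargetMachine (for example AMDGPU or NVPTX) should switch the pass to the per-lane extractelement form.

// Hypothetical driver: runs ScalarizeMaskedMemIntrinPass over a whole module.
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"

using namespace llvm;

// Pass a TargetMachine* so the pass sees target-specific TTI (and therefore a
// meaningful hasBranchDivergence()); with nullptr the generic TTI is used.
void runScalarizeMaskedMemIntrin(Module &M, TargetMachine *TM = nullptr) {
  PassBuilder PB(TM);
  LoopAnalysisManager LAM;
  FunctionAnalysisManager FAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  ModulePassManager MPM;
  MPM.addPass(
      createModuleToFunctionPassAdaptor(ScalarizeMaskedMemIntrinPass()));
  MPM.run(M, MAM);
}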