1212// ===----------------------------------------------------------------------===//
1313
1414#include " llvm/CodeGen/ExpandReductions.h"
15+ #include " llvm/Analysis/DomTreeUpdater.h"
1516#include " llvm/Analysis/TargetTransformInfo.h"
1617#include " llvm/CodeGen/Passes.h"
18+ #include " llvm/IR/BasicBlock.h"
19+ #include " llvm/IR/Constants.h"
20+ #include " llvm/IR/DerivedTypes.h"
21+ #include " llvm/IR/Dominators.h"
1722#include " llvm/IR/IRBuilder.h"
1823#include " llvm/IR/InstIterator.h"
24+ #include " llvm/IR/Instruction.h"
1925#include " llvm/IR/IntrinsicInst.h"
2026#include " llvm/IR/Intrinsics.h"
2127#include " llvm/InitializePasses.h"
2228#include " llvm/Pass.h"
29+ #include " llvm/Support/ErrorHandling.h"
30+ #include " llvm/Support/MathExtras.h"
31+ #include " llvm/Transforms/Utils/BasicBlockUtils.h"
2332#include " llvm/Transforms/Utils/LoopUtils.h"
33+ #include < optional>
2434
2535using namespace llvm ;
2636
2737namespace {
2838
29- bool expandReductions (Function &F, const TargetTransformInfo *TTI) {
30- bool Changed = false ;
39+ void updateDomTreeForScalableExpansion (DominatorTree *DT, BasicBlock *Preheader,
40+ BasicBlock *Loop, BasicBlock *Exit) {
41+ DT->addNewBlock (Loop, Preheader);
42+ DT->changeImmediateDominator (Exit, Loop);
43+ assert (DT->verify (DominatorTree::VerificationLevel::Fast));
44+ }
45+
46+ // / Expand a reduction on a scalable vector into a loop
47+ // / that iterates over one element after the other.
48+ Value *expandScalableReduction (IRBuilderBase &Builder, IntrinsicInst *II,
49+ Value *Acc, Value *Vec,
50+ Instruction::BinaryOps BinOp,
51+ DominatorTree *DT) {
52+ ScalableVectorType *VecTy = cast<ScalableVectorType>(Vec->getType ());
53+
54+ // Split the original BB in two and create a new BB between them,
55+ // which will be a loop.
56+ BasicBlock *BeforeBB = II->getParent ();
57+ BasicBlock *AfterBB = SplitBlock (BeforeBB, II, DT);
58+ BasicBlock *LoopBB = BasicBlock::Create (Builder.getContext (), " rdx.loop" ,
59+ BeforeBB->getParent (), AfterBB);
60+ BeforeBB->getTerminator ()->setSuccessor (0 , LoopBB);
61+
62+ // Calculate the number of elements in the vector:
63+ Builder.SetInsertPoint (BeforeBB->getTerminator ());
64+ Value *NumElts =
65+ Builder.CreateVScale (Builder.getInt64 (VecTy->getMinNumElements ()));
66+
67+ // Create two PHIs, one for the index of the current lane and one for
68+ // the reduction.
69+ Builder.SetInsertPoint (LoopBB);
70+ PHINode *IV = Builder.CreatePHI (Builder.getInt64Ty (), 2 , " index" );
71+ IV->addIncoming (Builder.getInt64 (0 ), BeforeBB);
72+ PHINode *RdxPhi = Builder.CreatePHI (VecTy->getScalarType (), 2 , " rdx.phi" );
73+ RdxPhi->addIncoming (Acc, BeforeBB);
74+
75+ Value *IVInc =
76+ Builder.CreateAdd (IV, Builder.getInt64 (1 ), " index.next" , true , true );
77+ IV->addIncoming (IVInc, LoopBB);
78+
79+ // Extract the value at the current lane from the vector and perform
80+ // the scalar reduction binop:
81+ Value *Lane = Builder.CreateExtractElement (Vec, IV, " elm" );
82+ Value *Rdx = Builder.CreateBinOp (BinOp, RdxPhi, Lane, " rdx" );
83+ RdxPhi->addIncoming (Rdx, LoopBB);
84+
85+ // Exit when all lanes have been treated (assuming there will be at least
86+ // one element in the vector):
87+ Value *Done = Builder.CreateCmp (CmpInst::ICMP_EQ, IVInc, NumElts, " exitcond" );
88+ Builder.CreateCondBr (Done, AfterBB, LoopBB);
89+
90+ if (DT)
91+ updateDomTreeForScalableExpansion (DT, BeforeBB, LoopBB, AfterBB);
92+
93+ return Rdx;
94+ }
95+
96+ // / Expand a reduction on a scalable vector in a parallel-tree like
97+ // / manner, meaning halving the number of elements to treat in every
98+ // / iteration.
99+ Value *expandScalableTreeReduction (
100+ IRBuilderBase &Builder, IntrinsicInst *II, std::optional<Value *> Acc,
101+ Value *Vec, Instruction::BinaryOps BinOp,
102+ function_ref<bool (Constant *)> IsNeutralElement, DominatorTree *DT,
103+ std::optional<unsigned> FixedVScale) {
104+ ScalableVectorType *VecTy = cast<ScalableVectorType>(Vec->getType ());
105+ ScalableVectorType *VecTyX2 = ScalableVectorType::get (
106+ VecTy->getScalarType (), VecTy->getMinNumElements () * 2 );
107+
108+ // If the VScale is fixed, do not generate a loop, and instead to
109+ // something similar to llvm::getShuffleReduction(). That function
110+ // cannot be used directly because it uses shuffle masks, which
111+ // are not avaiable for scalable vectors (even if vscale is fixed).
112+ // The approach is effectively the same.
113+ if (FixedVScale.has_value ()) {
114+ unsigned VF = VecTy->getMinNumElements () * FixedVScale.value ();
115+ assert (isPowerOf2_64 (VF));
116+ for (unsigned I = VF; I != 1 ; I >>= 1 ) {
117+ Value *Extended = Builder.CreateInsertVector (
118+ VecTyX2, PoisonValue::get (VecTyX2), Vec, Builder.getInt64 (0 ));
119+ Value *Pair = Builder.CreateIntrinsic (Intrinsic::vector_deinterleave2,
120+ {VecTyX2}, {Extended});
121+
122+ Value *Vec1 = Builder.CreateExtractValue (Pair, {0 });
123+ Value *Vec2 = Builder.CreateExtractValue (Pair, {1 });
124+ Vec = Builder.CreateBinOp (BinOp, Vec1, Vec2, " rdx" );
125+ }
126+ Value *FinalVal = Builder.CreateExtractElement (Vec, uint64_t (0 ));
127+ if (Acc)
128+ if (auto *C = dyn_cast<Constant>(*Acc); !C || !IsNeutralElement (C))
129+ FinalVal = Builder.CreateBinOp (BinOp, *Acc, FinalVal, " rdx.final" );
130+ return FinalVal;
131+ }
132+
133+ // Split the original BB in two and create a new BB between them,
134+ // which will be a loop.
135+ BasicBlock *BeforeBB = II->getParent ();
136+ BasicBlock *AfterBB = SplitBlock (BeforeBB, II, DT);
137+ BasicBlock *LoopBB = BasicBlock::Create (Builder.getContext (), " rdx.loop" ,
138+ BeforeBB->getParent (), AfterBB);
139+ BeforeBB->getTerminator ()->setSuccessor (0 , LoopBB);
140+
141+ // This tree reduction only needs to do log2(N) iterations.
142+ // Note: Calculating log2(N) using count-trailing-zeros (cttz) only works if
143+ // `vscale` the vector size is a power of two.
144+ Builder.SetInsertPoint (BeforeBB->getTerminator ());
145+ Value *NumElts =
146+ Builder.CreateVScale (Builder.getInt64 (VecTy->getMinNumElements ()));
147+ Value *NumIters = Builder.CreateIntrinsic (NumElts->getType (), Intrinsic::cttz,
148+ {NumElts, Builder.getTrue ()});
149+
150+ // Create two PHIs, one for the IV and one for the reduction.
151+ Builder.SetInsertPoint (LoopBB);
152+ PHINode *IV = Builder.CreatePHI (Builder.getInt64Ty (), 2 , " iter" );
153+ IV->addIncoming (Builder.getInt64 (0 ), BeforeBB);
154+ PHINode *VecPhi = Builder.CreatePHI (VecTy, 2 , " rdx.phi" );
155+ VecPhi->addIncoming (Vec, BeforeBB);
156+
157+ Value *IVInc =
158+ Builder.CreateAdd (IV, Builder.getInt64 (1 ), " iter.next" , true , true );
159+ IV->addIncoming (IVInc, LoopBB);
160+
161+ // The deinterleave intrinsic takes a vector of, for example, type
162+ // <vscale x 8 x float> and produces a pair of vectors with half the size,
163+ // so 2 x <vscale x 4 x float>. An insert vector operation is used to
164+ // create a double-sized vector where the upper half is poison, because
165+ // we never care about that upper half anyways!
166+ Value *Extended = Builder.CreateInsertVector (
167+ VecTyX2, PoisonValue::get (VecTyX2), VecPhi, Builder.getInt64 (0 ));
168+ Value *Pair = Builder.CreateIntrinsic (Intrinsic::vector_deinterleave2,
169+ {VecTyX2}, {Extended});
170+ Value *Vec1 = Builder.CreateExtractValue (Pair, {0 });
171+ Value *Vec2 = Builder.CreateExtractValue (Pair, {1 });
172+ Value *Rdx = Builder.CreateBinOp (BinOp, Vec1, Vec2, " rdx" );
173+ VecPhi->addIncoming (Rdx, LoopBB);
174+
175+ // Reduction-loop exit condition:
176+ Value *Done =
177+ Builder.CreateCmp (CmpInst::ICMP_EQ, IVInc, NumIters, " exitcond" );
178+ Builder.CreateCondBr (Done, AfterBB, LoopBB);
179+ Builder.SetInsertPoint (AfterBB, AfterBB->getFirstInsertionPt ());
180+ Value *FinalVal = Builder.CreateExtractElement (Rdx, uint64_t (0 ));
181+
182+ // If the Acc value is not the neutral element of the reduction operation,
183+ // then we need to do the binop one last time with the end result of the
184+ // tree reduction.
185+ if (Acc)
186+ if (auto *C = dyn_cast<Constant>(*Acc); !C || !IsNeutralElement (C))
187+ FinalVal = Builder.CreateBinOp (BinOp, *Acc, FinalVal, " rdx.final" );
188+
189+ if (DT)
190+ updateDomTreeForScalableExpansion (DT, BeforeBB, LoopBB, AfterBB);
191+
192+ return FinalVal;
193+ }
194+
195+ std::pair<bool , bool > expandReductions (Function &F,
196+ const TargetTransformInfo *TTI,
197+ DominatorTree *DT) {
198+ bool Changed = false , CFGChanged = false ;
31199 SmallVector<IntrinsicInst *, 4 > Worklist;
32200 for (auto &I : instructions (F)) {
33201 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
@@ -54,6 +222,12 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
54222 }
55223 }
56224
225+ const auto &Attrs = F.getAttributes ().getFnAttrs ();
226+ unsigned MinVScale = Attrs.getVScaleRangeMin ();
227+ std::optional<unsigned > FixedVScale = Attrs.getVScaleRangeMax ();
228+ if (FixedVScale != MinVScale)
229+ FixedVScale = std::nullopt ;
230+
57231 for (auto *II : Worklist) {
58232 FastMathFlags FMF =
59233 isa<FPMathOperator>(II) ? II->getFastMathFlags () : FastMathFlags{};
@@ -74,7 +248,34 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
74248 // and it can't be handled by generating a shuffle sequence.
75249 Value *Acc = II->getArgOperand (0 );
76250 Value *Vec = II->getArgOperand (1 );
77- unsigned RdxOpcode = getArithmeticReductionInstruction (ID);
251+ auto RdxOpcode =
252+ Instruction::BinaryOps (getArithmeticReductionInstruction (ID));
253+
254+ bool ScalableTy = Vec->getType ()->isScalableTy ();
255+ if (ScalableTy && (!FixedVScale || FMF.allowReassoc ())) {
256+ CFGChanged |= !FixedVScale;
257+ assert (TTI->isVScaleKnownToBeAPowerOfTwo () &&
258+ " Scalable tree reduction unimplemented for targets with a "
259+ " VScale not known to be a power of 2." );
260+ if (FMF.allowReassoc ())
261+ Rdx = expandScalableTreeReduction (
262+ Builder, II, Acc, Vec, RdxOpcode,
263+ [&](Constant *C) {
264+ switch (ID) {
265+ case Intrinsic::vector_reduce_fadd:
266+ return C->isZeroValue ();
267+ case Intrinsic::vector_reduce_fmul:
268+ return C->isOneValue ();
269+ default :
270+ llvm_unreachable (" Binop not handled" );
271+ }
272+ },
273+ DT, FixedVScale);
274+ else
275+ Rdx = expandScalableReduction (Builder, II, Acc, Vec, RdxOpcode, DT);
276+ break ;
277+ }
278+
78279 if (!FMF.allowReassoc ())
79280 Rdx = getOrderedReduction (Builder, Acc, Vec, RdxOpcode, RK);
80281 else {
@@ -125,10 +326,22 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
125326 case Intrinsic::vector_reduce_umax:
126327 case Intrinsic::vector_reduce_umin: {
127328 Value *Vec = II->getArgOperand (0 );
329+ unsigned RdxOpcode = getArithmeticReductionInstruction (ID);
330+ if (Vec->getType ()->isScalableTy ()) {
331+ CFGChanged |= !FixedVScale;
332+ assert (TTI->isVScaleKnownToBeAPowerOfTwo () &&
333+ " Scalable tree reduction unimplemented for targets with a "
334+ " VScale not known to be a power of 2." );
335+ Rdx = expandScalableTreeReduction (
336+ Builder, II, std::nullopt , Vec, Instruction::BinaryOps (RdxOpcode),
337+ [](Constant *C) -> bool { llvm_unreachable (" No accumulator!" ); },
338+ DT, FixedVScale);
339+ break ;
340+ }
341+
128342 if (!isPowerOf2_32 (
129343 cast<FixedVectorType>(Vec->getType ())->getNumElements ()))
130344 continue ;
131- unsigned RdxOpcode = getArithmeticReductionInstruction (ID);
132345 Rdx = getShuffleReduction (Builder, Vec, RdxOpcode, RS, RK);
133346 break ;
134347 }
@@ -150,7 +363,7 @@ bool expandReductions(Function &F, const TargetTransformInfo *TTI) {
150363 II->eraseFromParent ();
151364 Changed = true ;
152365 }
153- return Changed;
366+ return {CFGChanged, Changed} ;
154367}
155368
156369class ExpandReductions : public FunctionPass {
@@ -161,13 +374,15 @@ class ExpandReductions : public FunctionPass {
161374 }
162375
163376 bool runOnFunction (Function &F) override {
164- const auto *TTI =&getAnalysis<TargetTransformInfoWrapperPass>().getTTI (F);
165- return expandReductions (F, TTI);
377+ const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI (F);
378+ auto *DTA = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
379+ return expandReductions (F, TTI, DTA ? &DTA->getDomTree () : nullptr ).second ;
166380 }
167381
168382 void getAnalysisUsage (AnalysisUsage &AU) const override {
169383 AU.addRequired <TargetTransformInfoWrapperPass>();
170- AU.setPreservesCFG ();
384+ AU.addUsedIfAvailable <DominatorTreeWrapperPass>();
385+ AU.addPreserved <DominatorTreeWrapperPass>();
171386 }
172387};
173388}
@@ -186,9 +401,14 @@ FunctionPass *llvm::createExpandReductionsPass() {
186401PreservedAnalyses ExpandReductionsPass::run (Function &F,
187402 FunctionAnalysisManager &AM) {
188403 const auto &TTI = AM.getResult <TargetIRAnalysis>(F);
189- if (!expandReductions (F, &TTI))
404+ auto *DT = AM.getCachedResult <DominatorTreeAnalysis>(F);
405+ auto [CFGChanged, Changed] = expandReductions (F, &TTI, DT);
406+ if (!Changed)
190407 return PreservedAnalyses::all ();
191408 PreservedAnalyses PA;
192- PA.preserveSet <CFGAnalyses>();
409+ if (!CFGChanged)
410+ PA.preserveSet <CFGAnalyses>();
411+ else
412+ PA.preserve <DominatorTreeAnalysis>();
193413 return PA;
194414}
0 commit comments