1313
1414#include " InstCombineInternal.h"
1515#include " llvm/ADT/APInt.h"
16+ #include " llvm/ADT/SmallPtrSet.h"
1617#include " llvm/ADT/SmallVector.h"
1718#include " llvm/Analysis/InstructionSimplify.h"
1819#include " llvm/Analysis/ValueTracking.h"
@@ -666,6 +667,94 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
666667 return nullptr ;
667668}
668669
670+ // If we have the following pattern,
671+ // X = 1.0/sqrt(a)
672+ // R1 = X * X
673+ // R2 = a/sqrt(a)
674+ // then this method collects all the instructions that match R1 and R2.
675+ static bool getFSqrtDivOptPattern (Instruction *Div,
676+ SmallPtrSetImpl<Instruction *> &R1,
677+ SmallPtrSetImpl<Instruction *> &R2) {
678+ Value *A;
679+ if (match (Div, m_FDiv (m_FPOne (), m_Sqrt (m_Value (A)))) ||
680+ match (Div, m_FDiv (m_SpecificFP (-1.0 ), m_Sqrt (m_Value (A))))) {
681+ for (User *U : Div->users ()) {
682+ Instruction *I = cast<Instruction>(U);
683+ if (match (I, m_FMul (m_Specific (Div), m_Specific (Div))))
684+ R1.insert (I);
685+ }
686+
687+ CallInst *CI = cast<CallInst>(Div->getOperand (1 ));
688+ for (User *U : CI->users ()) {
689+ Instruction *I = cast<Instruction>(U);
690+ if (match (I, m_FDiv (m_Specific (A), m_Sqrt (m_Specific (A)))))
691+ R2.insert (I);
692+ }
693+ }
694+ return !R1.empty () && !R2.empty ();
695+ }
696+
697+ // Check legality for transforming
698+ // x = 1.0/sqrt(a)
699+ // r1 = x * x;
700+ // r2 = a/sqrt(a);
701+ //
702+ // TO
703+ //
704+ // r1 = 1/a
705+ // r2 = sqrt(a)
706+ // x = r1 * r2
707+ // This transform works only when 'a' is known positive.
708+ static bool isFSqrtDivToFMulLegal (Instruction *X,
709+ SmallPtrSetImpl<Instruction *> &R1,
710+ SmallPtrSetImpl<Instruction *> &R2) {
711+ // Check if the required pattern for the transformation exists.
712+ if (!getFSqrtDivOptPattern (X, R1, R2))
713+ return false ;
714+
715+ BasicBlock *BBx = X->getParent ();
716+ BasicBlock *BBr1 = (*R1.begin ())->getParent ();
717+ BasicBlock *BBr2 = (*R2.begin ())->getParent ();
718+
719+ CallInst *FSqrt = cast<CallInst>(X->getOperand (1 ));
720+ if (!FSqrt->hasAllowReassoc () || !FSqrt->hasNoNaNs () ||
721+ !FSqrt->hasNoSignedZeros () || !FSqrt->hasNoInfs ())
722+ return false ;
723+
724+ // We change x = 1/sqrt(a) to x = sqrt(a) * 1/a . This change isn't allowed
725+ // by recip fp as it is strictly meant to transform ops of type a/b to
726+ // a * 1/b. So, this can be considered as algebraic rewrite and reassoc flag
727+ // has been used(rather abused)in the past for algebraic rewrites.
728+ if (!X->hasAllowReassoc () || !X->hasAllowReciprocal () || !X->hasNoInfs ())
729+ return false ;
730+
731+ // Check the constraints on X, R1 and R2 combined.
732+ // fdiv instruction and one of the multiplications must reside in the same
733+ // block. If not, the optimized code may execute more ops than before and
734+ // this may hamper the performance.
735+ if (BBx != BBr1 && BBx != BBr2)
736+ return false ;
737+
738+ // Check the constraints on instructions in R1.
739+ if (any_of (R1, [BBr1](Instruction *I) {
740+ // When you have multiple instructions residing in R1 and R2
741+ // respectively, it's difficult to generate combinations of (R1,R2) and
742+ // then check if we have the required pattern. So, for now, just be
743+ // conservative.
744+ return (I->getParent () != BBr1 || !I->hasAllowReassoc ());
745+ }))
746+ return false ;
747+
748+ // Check the constraints on instructions in R2.
749+ return all_of (R2, [BBr2](Instruction *I) {
750+ // When you have multiple instructions residing in R1 and R2
751+ // respectively, it's difficult to generate combination of (R1,R2) and
752+ // then check if we have the required pattern. So, for now, just be
753+ // conservative.
754+ return (I->getParent () == BBr2 && I->hasAllowReassoc ());
755+ });
756+ }
757+
669758Instruction *InstCombinerImpl::foldFMulReassoc (BinaryOperator &I) {
670759 Value *Op0 = I.getOperand (0 );
671760 Value *Op1 = I.getOperand (1 );
@@ -1917,6 +2006,75 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
19172006 return BinaryOperator::CreateFMulFMF (Op0, NewSqrt, &I);
19182007}
19192008
2009+ // Change
2010+ // X = 1/sqrt(a)
2011+ // R1 = X * X
2012+ // R2 = a * X
2013+ //
2014+ // TO
2015+ //
2016+ // FDiv = 1/a
2017+ // FSqrt = sqrt(a)
2018+ // FMul = FDiv * FSqrt
2019+ // Replace Uses Of R1 With FDiv
2020+ // Replace Uses Of R2 With FSqrt
2021+ // Replace Uses Of X With FMul
2022+ static Instruction *
2023+ convertFSqrtDivIntoFMul (CallInst *CI, Instruction *X,
2024+ const SmallPtrSetImpl<Instruction *> &R1,
2025+ const SmallPtrSetImpl<Instruction *> &R2,
2026+ InstCombiner::BuilderTy &B, InstCombinerImpl *IC) {
2027+
2028+ B.SetInsertPoint (X);
2029+
2030+ // Have an instruction that is representative of all of instructions in R1 and
2031+ // get the most common fpmath metadata and fast-math flags on it.
2032+ Value *SqrtOp = CI->getArgOperand (0 );
2033+ auto *FDiv = cast<Instruction>(
2034+ B.CreateFDiv (ConstantFP::get (X->getType (), 1.0 ), SqrtOp));
2035+ auto *R1FPMathMDNode = (*R1.begin ())->getMetadata (LLVMContext::MD_fpmath);
2036+ FastMathFlags R1FMF = (*R1.begin ())->getFastMathFlags (); // Common FMF
2037+ for (Instruction *I : R1) {
2038+ R1FPMathMDNode = MDNode::getMostGenericFPMath (
2039+ R1FPMathMDNode, I->getMetadata (LLVMContext::MD_fpmath));
2040+ R1FMF &= I->getFastMathFlags ();
2041+ IC->replaceInstUsesWith (*I, FDiv);
2042+ IC->eraseInstFromFunction (*I);
2043+ }
2044+ FDiv->setMetadata (LLVMContext::MD_fpmath, R1FPMathMDNode);
2045+ FDiv->copyFastMathFlags (R1FMF);
2046+
2047+ // Have a single sqrt call instruction that is representative of all of
2048+ // instructions in R2 and get the most common fpmath metadata and fast-math
2049+ // flags on it.
2050+ auto *FSqrt = cast<CallInst>(CI->clone ());
2051+ FSqrt->insertBefore (CI);
2052+ auto *R2FPMathMDNode = (*R2.begin ())->getMetadata (LLVMContext::MD_fpmath);
2053+ FastMathFlags R2FMF = (*R2.begin ())->getFastMathFlags (); // Common FMF
2054+ for (Instruction *I : R2) {
2055+ R2FPMathMDNode = MDNode::getMostGenericFPMath (
2056+ R2FPMathMDNode, I->getMetadata (LLVMContext::MD_fpmath));
2057+ R2FMF &= I->getFastMathFlags ();
2058+ IC->replaceInstUsesWith (*I, FSqrt);
2059+ IC->eraseInstFromFunction (*I);
2060+ }
2061+ FSqrt->setMetadata (LLVMContext::MD_fpmath, R2FPMathMDNode);
2062+ FSqrt->copyFastMathFlags (R2FMF);
2063+
2064+ Instruction *FMul;
2065+ // If X = -1/sqrt(a) initially,then FMul = -(FDiv * FSqrt)
2066+ if (match (X, m_FDiv (m_SpecificFP (-1.0 ), m_Specific (CI)))) {
2067+ Value *Mul = B.CreateFMul (FDiv, FSqrt);
2068+ FMul = cast<Instruction>(B.CreateFNeg (Mul));
2069+ } else
2070+ FMul = cast<Instruction>(B.CreateFMul (FDiv, FSqrt));
2071+ FMul->copyMetadata (*X);
2072+ FMul->copyFastMathFlags (FastMathFlags::intersectRewrite (R1FMF, R2FMF) |
2073+ FastMathFlags::unionValue (R1FMF, R2FMF));
2074+ IC->replaceInstUsesWith (*X, FMul);
2075+ return IC->eraseInstFromFunction (*X);
2076+ }
2077+
19202078Instruction *InstCombinerImpl::visitFDiv (BinaryOperator &I) {
19212079 Module *M = I.getModule ();
19222080
@@ -1941,6 +2099,24 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
19412099 return R;
19422100
19432101 Value *Op0 = I.getOperand (0 ), *Op1 = I.getOperand (1 );
2102+
2103+ // Convert
2104+ // x = 1.0/sqrt(a)
2105+ // r1 = x * x;
2106+ // r2 = a/sqrt(a);
2107+ //
2108+ // TO
2109+ //
2110+ // r1 = 1/a
2111+ // r2 = sqrt(a)
2112+ // x = r1 * r2
2113+ SmallPtrSet<Instruction *, 2 > R1, R2;
2114+ if (isFSqrtDivToFMulLegal (&I, R1, R2)) {
2115+ CallInst *CI = cast<CallInst>(I.getOperand (1 ));
2116+ if (Instruction *D = convertFSqrtDivIntoFMul (CI, &I, R1, R2, Builder, this ))
2117+ return D;
2118+ }
2119+
19442120 if (isa<Constant>(Op0))
19452121 if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
19462122 if (Instruction *R = FoldOpIntoSelect (I, SI))
0 commit comments