1313
1414#include " InstCombineInternal.h"
1515#include " llvm/ADT/APInt.h"
16+ #include " llvm/ADT/SmallPtrSet.h"
1617#include " llvm/ADT/SmallVector.h"
1718#include " llvm/Analysis/InstructionSimplify.h"
1819#include " llvm/Analysis/ValueTracking.h"
@@ -657,6 +658,94 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
657658 return nullptr ;
658659}
659660
661+ // If we have the following pattern,
662+ // X = 1.0/sqrt(a)
663+ // R1 = X * X
664+ // R2 = a/sqrt(a)
665+ // then this method collects all the instructions that match R1 and R2.
666+ static bool getFSqrtDivOptPattern (Instruction *Div,
667+ SmallPtrSetImpl<Instruction *> &R1,
668+ SmallPtrSetImpl<Instruction *> &R2) {
669+ Value *A;
670+ if (match (Div, m_FDiv (m_FPOne (), m_Sqrt (m_Value (A)))) ||
671+ match (Div, m_FDiv (m_SpecificFP (-1.0 ), m_Sqrt (m_Value (A))))) {
672+ for (User *U : Div->users ()) {
673+ Instruction *I = cast<Instruction>(U);
674+ if (match (I, m_FMul (m_Specific (Div), m_Specific (Div))))
675+ R1.insert (I);
676+ }
677+
678+ CallInst *CI = cast<CallInst>(Div->getOperand (1 ));
679+ for (User *U : CI->users ()) {
680+ Instruction *I = cast<Instruction>(U);
681+ if (match (I, m_FDiv (m_Specific (A), m_Sqrt (m_Specific (A)))))
682+ R2.insert (I);
683+ }
684+ }
685+ return !R1.empty () && !R2.empty ();
686+ }
687+
688+ // Check legality for transforming
689+ // x = 1.0/sqrt(a)
690+ // r1 = x * x;
691+ // r2 = a/sqrt(a);
692+ //
693+ // TO
694+ //
695+ // r1 = 1/a
696+ // r2 = sqrt(a)
697+ // x = r1 * r2
698+ // This transform works only when 'a' is known positive.
699+ static bool isFSqrtDivToFMulLegal (Instruction *X,
700+ SmallPtrSetImpl<Instruction *> &R1,
701+ SmallPtrSetImpl<Instruction *> &R2) {
702+ // Check if the required pattern for the transformation exists.
703+ if (!getFSqrtDivOptPattern (X, R1, R2))
704+ return false ;
705+
706+ BasicBlock *BBx = X->getParent ();
707+ BasicBlock *BBr1 = (*R1.begin ())->getParent ();
708+ BasicBlock *BBr2 = (*R2.begin ())->getParent ();
709+
710+ CallInst *FSqrt = cast<CallInst>(X->getOperand (1 ));
711+ if (!FSqrt->hasAllowReassoc () || !FSqrt->hasNoNaNs () ||
712+ !FSqrt->hasNoSignedZeros () || !FSqrt->hasNoInfs ())
713+ return false ;
714+
715+ // We change x = 1/sqrt(a) to x = sqrt(a) * 1/a . This change isn't allowed
716+ // by recip fp as it is strictly meant to transform ops of type a/b to
717+ // a * 1/b. So, this can be considered as algebraic rewrite and reassoc flag
718+ // has been used(rather abused)in the past for algebraic rewrites.
719+ if (!X->hasAllowReassoc () || !X->hasAllowReciprocal () || !X->hasNoInfs ())
720+ return false ;
721+
722+ // Check the constraints on X, R1 and R2 combined.
723+ // fdiv instruction and one of the multiplications must reside in the same
724+ // block. If not, the optimized code may execute more ops than before and
725+ // this may hamper the performance.
726+ if (BBx != BBr1 && BBx != BBr2)
727+ return false ;
728+
729+ // Check the constraints on instructions in R1.
730+ if (any_of (R1, [BBr1](Instruction *I) {
731+ // When you have multiple instructions residing in R1 and R2
732+ // respectively, it's difficult to generate combinations of (R1,R2) and
733+ // then check if we have the required pattern. So, for now, just be
734+ // conservative.
735+ return (I->getParent () != BBr1 || !I->hasAllowReassoc ());
736+ }))
737+ return false ;
738+
739+ // Check the constraints on instructions in R2.
740+ return all_of (R2, [BBr2](Instruction *I) {
741+ // When you have multiple instructions residing in R1 and R2
742+ // respectively, it's difficult to generate combination of (R1,R2) and
743+ // then check if we have the required pattern. So, for now, just be
744+ // conservative.
745+ return (I->getParent () == BBr2 && I->hasAllowReassoc ());
746+ });
747+ }
748+
660749Instruction *InstCombinerImpl::foldFMulReassoc (BinaryOperator &I) {
661750 Value *Op0 = I.getOperand (0 );
662751 Value *Op1 = I.getOperand (1 );
@@ -1913,6 +2002,75 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
19132002 return BinaryOperator::CreateFMulFMF (Op0, NewSqrt, &I);
19142003}
19152004
2005+ // Change
2006+ // X = 1/sqrt(a)
2007+ // R1 = X * X
2008+ // R2 = a * X
2009+ //
2010+ // TO
2011+ //
2012+ // FDiv = 1/a
2013+ // FSqrt = sqrt(a)
2014+ // FMul = FDiv * FSqrt
2015+ // Replace Uses Of R1 With FDiv
2016+ // Replace Uses Of R2 With FSqrt
2017+ // Replace Uses Of X With FMul
2018+ static Instruction *
2019+ convertFSqrtDivIntoFMul (CallInst *CI, Instruction *X,
2020+ const SmallPtrSetImpl<Instruction *> &R1,
2021+ const SmallPtrSetImpl<Instruction *> &R2,
2022+ InstCombiner::BuilderTy &B, InstCombinerImpl *IC) {
2023+
2024+ B.SetInsertPoint (X);
2025+
2026+ // Have an instruction that is representative of all of instructions in R1 and
2027+ // get the most common fpmath metadata and fast-math flags on it.
2028+ Value *SqrtOp = CI->getArgOperand (0 );
2029+ auto *FDiv = cast<Instruction>(
2030+ B.CreateFDiv (ConstantFP::get (X->getType (), 1.0 ), SqrtOp));
2031+ auto *R1FPMathMDNode = (*R1.begin ())->getMetadata (LLVMContext::MD_fpmath);
2032+ FastMathFlags R1FMF = (*R1.begin ())->getFastMathFlags (); // Common FMF
2033+ for (Instruction *I : R1) {
2034+ R1FPMathMDNode = MDNode::getMostGenericFPMath (
2035+ R1FPMathMDNode, I->getMetadata (LLVMContext::MD_fpmath));
2036+ R1FMF &= I->getFastMathFlags ();
2037+ IC->replaceInstUsesWith (*I, FDiv);
2038+ IC->eraseInstFromFunction (*I);
2039+ }
2040+ FDiv->setMetadata (LLVMContext::MD_fpmath, R1FPMathMDNode);
2041+ FDiv->copyFastMathFlags (R1FMF);
2042+
2043+ // Have a single sqrt call instruction that is representative of all of
2044+ // instructions in R2 and get the most common fpmath metadata and fast-math
2045+ // flags on it.
2046+ auto *FSqrt = cast<CallInst>(CI->clone ());
2047+ FSqrt->insertBefore (CI);
2048+ auto *R2FPMathMDNode = (*R2.begin ())->getMetadata (LLVMContext::MD_fpmath);
2049+ FastMathFlags R2FMF = (*R2.begin ())->getFastMathFlags (); // Common FMF
2050+ for (Instruction *I : R2) {
2051+ R2FPMathMDNode = MDNode::getMostGenericFPMath (
2052+ R2FPMathMDNode, I->getMetadata (LLVMContext::MD_fpmath));
2053+ R2FMF &= I->getFastMathFlags ();
2054+ IC->replaceInstUsesWith (*I, FSqrt);
2055+ IC->eraseInstFromFunction (*I);
2056+ }
2057+ FSqrt->setMetadata (LLVMContext::MD_fpmath, R2FPMathMDNode);
2058+ FSqrt->copyFastMathFlags (R2FMF);
2059+
2060+ Instruction *FMul;
2061+ // If X = -1/sqrt(a) initially,then FMul = -(FDiv * FSqrt)
2062+ if (match (X, m_FDiv (m_SpecificFP (-1.0 ), m_Specific (CI)))) {
2063+ Value *Mul = B.CreateFMul (FDiv, FSqrt);
2064+ FMul = cast<Instruction>(B.CreateFNeg (Mul));
2065+ } else
2066+ FMul = cast<Instruction>(B.CreateFMul (FDiv, FSqrt));
2067+ FMul->copyMetadata (*X);
2068+ FMul->copyFastMathFlags (FastMathFlags::intersectRewrite (R1FMF, R2FMF) |
2069+ FastMathFlags::unionValue (R1FMF, R2FMF));
2070+ IC->replaceInstUsesWith (*X, FMul);
2071+ return IC->eraseInstFromFunction (*X);
2072+ }
2073+
19162074Instruction *InstCombinerImpl::visitFDiv (BinaryOperator &I) {
19172075 Module *M = I.getModule ();
19182076
@@ -1937,6 +2095,24 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
19372095 return R;
19382096
19392097 Value *Op0 = I.getOperand (0 ), *Op1 = I.getOperand (1 );
2098+
2099+ // Convert
2100+ // x = 1.0/sqrt(a)
2101+ // r1 = x * x;
2102+ // r2 = a/sqrt(a);
2103+ //
2104+ // TO
2105+ //
2106+ // r1 = 1/a
2107+ // r2 = sqrt(a)
2108+ // x = r1 * r2
2109+ SmallPtrSet<Instruction *, 2 > R1, R2;
2110+ if (isFSqrtDivToFMulLegal (&I, R1, R2)) {
2111+ CallInst *CI = cast<CallInst>(I.getOperand (1 ));
2112+ if (Instruction *D = convertFSqrtDivIntoFMul (CI, &I, R1, R2, Builder, this ))
2113+ return D;
2114+ }
2115+
19402116 if (isa<Constant>(Op0))
19412117 if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
19422118 if (Instruction *R = FoldOpIntoSelect (I, SI))
0 commit comments