@@ -129,7 +129,9 @@ class VectorCombine {
129129 bool foldExtractedCmps (Instruction &I);
130130 bool foldBinopOfReductions (Instruction &I);
131131 bool foldSingleElementStore (Instruction &I);
132- bool scalarizeLoadExtract (Instruction &I);
132+ bool scalarizeLoad (Instruction &I);
133+ bool scalarizeLoadExtract (LoadInst *LI, VectorType *VecTy, Value *Ptr);
134+ bool scalarizeLoadBitcast (LoadInst *LI, VectorType *VecTy, Value *Ptr);
133135 bool scalarizeExtExtract (Instruction &I);
134136 bool foldConcatOfBoolMasks (Instruction &I);
135137 bool foldPermuteOfBinops (Instruction &I);
@@ -1845,49 +1847,42 @@ bool VectorCombine::foldSingleElementStore(Instruction &I) {
18451847 return false ;
18461848}
18471849
1848- // / Try to scalarize vector loads feeding extractelement instructions.
1849- bool VectorCombine::scalarizeLoadExtract (Instruction &I) {
1850- if (!TTI.allowVectorElementIndexingUsingGEP ())
1851- return false ;
1852-
1850+ // / Try to scalarize vector loads feeding extractelement or bitcast
1851+ // / instructions.
1852+ bool VectorCombine::scalarizeLoad (Instruction &I) {
18531853 Value *Ptr;
18541854 if (!match (&I, m_Load (m_Value (Ptr))))
18551855 return false ;
18561856
18571857 auto *LI = cast<LoadInst>(&I);
18581858 auto *VecTy = cast<VectorType>(LI->getType ());
1859- if (LI->isVolatile () || !DL->typeSizeEqualsStoreSize (VecTy->getScalarType ()))
1859+ if (!VecTy || LI->isVolatile () ||
1860+ !DL->typeSizeEqualsStoreSize (VecTy->getScalarType ()))
18601861 return false ;
18611862
1862- InstructionCost OriginalCost =
1863- TTI.getMemoryOpCost (Instruction::Load, VecTy, LI->getAlign (),
1864- LI->getPointerAddressSpace (), CostKind);
1865- InstructionCost ScalarizedCost = 0 ;
1866-
1863+ // Check what type of users we have and ensure no memory modifications between
1864+ // the load and its users.
1865+ bool AllExtracts = true ;
1866+ bool AllBitcasts = true ;
18671867 Instruction *LastCheckedInst = LI;
18681868 unsigned NumInstChecked = 0 ;
1869- DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1870- auto FailureGuard = make_scope_exit ([&]() {
1871- // If the transform is aborted, discard the ScalarizationResults.
1872- for (auto &Pair : NeedFreeze)
1873- Pair.second .discard ();
1874- });
18751869
1876- // Check if all users of the load are extracts with no memory modifications
1877- // between the load and the extract. Compute the cost of both the original
1878- // code and the scalarized version.
18791870 for (User *U : LI->users ()) {
1880- auto *UI = dyn_cast<ExtractElementInst >(U);
1881- if (!UI || UI->getParent () != LI->getParent ())
1871+ auto *UI = dyn_cast<Instruction >(U);
1872+ if (!UI || UI->getParent () != LI->getParent () || UI-> use_empty () )
18821873 return false ;
18831874
1884- // If any extract is waiting to be erased, then bail out as this will
1875+ // If any user is waiting to be erased, then bail out as this will
18851876 // distort the cost calculation and possibly lead to infinite loops.
18861877 if (UI->use_empty ())
18871878 return false ;
18881879
1889- // Check if any instruction between the load and the extract may modify
1890- // memory.
1880+ if (!isa<ExtractElementInst>(UI))
1881+ AllExtracts = false ;
1882+ if (!isa<BitCastInst>(UI))
1883+ AllBitcasts = false ;
1884+
1885+ // Check if any instruction between the load and the user may modify memory.
18911886 if (LastCheckedInst->comesBefore (UI)) {
18921887 for (Instruction &I :
18931888 make_range (std::next (LI->getIterator ()), UI->getIterator ())) {
@@ -1899,6 +1894,35 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
18991894 }
19001895 LastCheckedInst = UI;
19011896 }
1897+ }
1898+
1899+ if (AllExtracts)
1900+ return scalarizeLoadExtract (LI, VecTy, Ptr);
1901+ if (AllBitcasts)
1902+ return scalarizeLoadBitcast (LI, VecTy, Ptr);
1903+ return false ;
1904+ }
1905+
1906+ // / Try to scalarize vector loads feeding extractelement instructions.
1907+ bool VectorCombine::scalarizeLoadExtract (LoadInst *LI, VectorType *VecTy,
1908+ Value *Ptr) {
1909+ if (!TTI.allowVectorElementIndexingUsingGEP ())
1910+ return false ;
1911+
1912+ InstructionCost OriginalCost =
1913+ TTI.getMemoryOpCost (Instruction::Load, VecTy, LI->getAlign (),
1914+ LI->getPointerAddressSpace (), CostKind);
1915+ InstructionCost ScalarizedCost = 0 ;
1916+
1917+ DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1918+ auto FailureGuard = make_scope_exit ([&]() {
1919+ // If the transform is aborted, discard the ScalarizationResults.
1920+ for (auto &Pair : NeedFreeze)
1921+ Pair.second .discard ();
1922+ });
1923+
1924+ for (User *U : LI->users ()) {
1925+ auto *UI = cast<ExtractElementInst>(U);
19021926
19031927 auto ScalarIdx =
19041928 canScalarizeAccess (VecTy, UI->getIndexOperand (), LI, AC, DT);
@@ -1920,7 +1944,7 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
19201944 nullptr , nullptr , CostKind);
19211945 }
19221946
1923- LLVM_DEBUG (dbgs () << " Found all extractions of a vector load: " << I
1947+ LLVM_DEBUG (dbgs () << " Found all extractions of a vector load: " << *LI
19241948 << " \n LoadExtractCost: " << OriginalCost
19251949 << " vs ScalarizedCost: " << ScalarizedCost << " \n " );
19261950
@@ -1966,6 +1990,70 @@ bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
19661990 return true ;
19671991}
19681992
1993+ // / Try to scalarize vector loads feeding bitcast instructions.
1994+ bool VectorCombine::scalarizeLoadBitcast (LoadInst *LI, VectorType *VecTy,
1995+ Value *Ptr) {
1996+ InstructionCost OriginalCost =
1997+ TTI.getMemoryOpCost (Instruction::Load, VecTy, LI->getAlign (),
1998+ LI->getPointerAddressSpace (), CostKind);
1999+
2000+ Type *TargetScalarType = nullptr ;
2001+ unsigned VecBitWidth = DL->getTypeSizeInBits (VecTy);
2002+
2003+ for (User *U : LI->users ()) {
2004+ auto *BC = cast<BitCastInst>(U);
2005+
2006+ Type *DestTy = BC->getDestTy ();
2007+ if (!DestTy->isIntegerTy () && !DestTy->isFloatingPointTy ())
2008+ return false ;
2009+
2010+ unsigned DestBitWidth = DL->getTypeSizeInBits (DestTy);
2011+ if (DestBitWidth != VecBitWidth)
2012+ return false ;
2013+
2014+ // All bitcasts should target the same scalar type.
2015+ if (!TargetScalarType)
2016+ TargetScalarType = DestTy;
2017+ else if (TargetScalarType != DestTy)
2018+ return false ;
2019+
2020+ OriginalCost +=
2021+ TTI.getCastInstrCost (Instruction::BitCast, TargetScalarType, VecTy,
2022+ TTI.getCastContextHint (BC), CostKind, BC);
2023+ }
2024+
2025+ if (!TargetScalarType || LI->user_empty ())
2026+ return false ;
2027+ InstructionCost ScalarizedCost =
2028+ TTI.getMemoryOpCost (Instruction::Load, TargetScalarType, LI->getAlign (),
2029+ LI->getPointerAddressSpace (), CostKind);
2030+
2031+ LLVM_DEBUG (dbgs () << " Found vector load feeding only bitcasts: " << *LI
2032+ << " \n OriginalCost: " << OriginalCost
2033+ << " vs ScalarizedCost: " << ScalarizedCost << " \n " );
2034+
2035+ if (ScalarizedCost >= OriginalCost)
2036+ return false ;
2037+
2038+ // Ensure we add the load back to the worklist BEFORE its users so they can
2039+ // erased in the correct order.
2040+ Worklist.push (LI);
2041+
2042+ Builder.SetInsertPoint (LI);
2043+ auto *ScalarLoad =
2044+ Builder.CreateLoad (TargetScalarType, Ptr, LI->getName () + " .scalar" );
2045+ ScalarLoad->setAlignment (LI->getAlign ());
2046+ ScalarLoad->copyMetadata (*LI);
2047+
2048+ // Replace all bitcast users with the scalar load.
2049+ for (User *U : LI->users ()) {
2050+ auto *BC = cast<BitCastInst>(U);
2051+ replaceValue (*BC, *ScalarLoad, false );
2052+ }
2053+
2054+ return true ;
2055+ }
2056+
19692057bool VectorCombine::scalarizeExtExtract (Instruction &I) {
19702058 if (!TTI.allowVectorElementIndexingUsingGEP ())
19712059 return false ;
@@ -4555,7 +4643,7 @@ bool VectorCombine::run() {
45554643 if (IsVectorType) {
45564644 if (scalarizeOpOrCmp (I))
45574645 return true ;
4558- if (scalarizeLoadExtract (I))
4646+ if (scalarizeLoad (I))
45594647 return true ;
45604648 if (scalarizeExtExtract (I))
45614649 return true ;
0 commit comments