@@ -1795,8 +1795,7 @@ void MemoryDepChecker::mergeInStatus(VectorizationSafetyStatus S) {
17951795// / }
17961796static bool isSafeDependenceDistance (const DataLayout &DL, ScalarEvolution &SE,
17971797 const SCEV &MaxBTC, const SCEV &Dist,
1798- uint64_t MaxStride,
1799- uint64_t TypeByteSize) {
1798+ uint64_t MaxStride) {
18001799
18011800 // If we can prove that
18021801 // (**) |Dist| > MaxBTC * Step
@@ -1815,8 +1814,7 @@ static bool isSafeDependenceDistance(const DataLayout &DL, ScalarEvolution &SE,
18151814 // will be executed only if LoopCount >= VF, proving distance >= LoopCount
18161815 // also guarantees that distance >= VF.
18171816 //
1818- const uint64_t ByteStride = MaxStride * TypeByteSize;
1819- const SCEV *Step = SE.getConstant (MaxBTC.getType (), ByteStride);
1817+ const SCEV *Step = SE.getConstant (MaxBTC.getType (), MaxStride);
18201818 const SCEV *Product = SE.getMulExpr (&MaxBTC, Step);
18211819
18221820 const SCEV *CastedDist = &Dist;
@@ -1860,25 +1858,23 @@ static bool areStridedAccessesIndependent(uint64_t Distance, uint64_t Stride,
18601858 if (Distance % TypeByteSize)
18611859 return false ;
18621860
1863- uint64_t ScaledDist = Distance / TypeByteSize;
1864-
1865- // No dependence if the scaled distance is not multiple of the stride.
1861+ // No dependence if the distance is not multiple of the stride.
18661862 // E.g.
18671863 // for (i = 0; i < 1024 ; i += 4)
18681864 // A[i+2] = A[i] + 1;
18691865 //
1870- // Two accesses in memory (scaled distance is 2, stride is 4):
1866+ // Two accesses in memory (distance is 2, stride is 4):
18711867 // | A[0] | | | | A[4] | | | |
18721868 // | | | A[2] | | | | A[6] | |
18731869 //
18741870 // E.g.
18751871 // for (i = 0; i < 1024 ; i += 3)
18761872 // A[i+4] = A[i] + 1;
18771873 //
1878- // Two accesses in memory (scaled distance is 4, stride is 3):
1874+ // Two accesses in memory (distance is 4, stride is 3):
18791875 // | A[0] | | | A[3] | | | A[6] | | |
18801876 // | | | | | A[4] | | | A[7] | |
1881- return ScaledDist % Stride;
1877+ return Distance % Stride;
18821878}
18831879
18841880std::variant<MemoryDepChecker::Dependence::DepType,
@@ -1987,25 +1983,27 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
19871983 return MemoryDepChecker::Dependence::Unknown;
19881984 }
19891985
1990- uint64_t TypeByteSize = DL.getTypeAllocSize (ATy);
1991- bool HasSameSize =
1992- DL.getTypeStoreSizeInBits (ATy) == DL.getTypeStoreSizeInBits (BTy);
1993- if (!HasSameSize)
1994- TypeByteSize = 0 ;
1986+ TypeSize AStoreSz = DL.getTypeStoreSize (ATy),
1987+ BStoreSz = DL.getTypeStoreSize (BTy);
1988+
1989+ // If store sizes are not the same, set TypeByteSize to zero, so we can check
1990+ // it in the caller.
1991+ uint64_t ASz = DL.getTypeAllocSize (ATy), BSz = DL.getTypeAllocSize (BTy),
1992+ TypeByteSize = AStoreSz == BStoreSz ? BSz : 0 ;
19951993
1996- StrideAPtrInt = std::abs (StrideAPtrInt);
1997- StrideBPtrInt = std::abs (StrideBPtrInt);
1994+ uint64_t StrideAScaled = std::abs (StrideAPtrInt) * ASz ;
1995+ uint64_t StrideBScaled = std::abs (StrideBPtrInt) * BSz ;
19981996
1999- uint64_t MaxStride = std::max (StrideAPtrInt, StrideBPtrInt );
1997+ uint64_t MaxStride = std::max (StrideAScaled, StrideBScaled );
20001998
20011999 std::optional<uint64_t > CommonStride;
2002- if (StrideAPtrInt == StrideBPtrInt )
2003- CommonStride = StrideAPtrInt ;
2000+ if (StrideAScaled == StrideBScaled )
2001+ CommonStride = StrideAScaled ;
20042002
20052003 // TODO: Historically, we don't retry with runtime checks unless the
20062004 // (unscaled) strides are the same. Fix this once the condition for runtime
20072005 // checks in isDependent is fixed.
2008- bool ShouldRetryWithRuntimeCheck = CommonStride. has_value () ;
2006+ bool ShouldRetryWithRuntimeCheck = StrideAPtrInt == StrideBPtrInt ;
20092007
20102008 return DepDistanceStrideAndSizeInfo (Dist, MaxStride, CommonStride,
20112009 ShouldRetryWithRuntimeCheck, TypeByteSize,
@@ -2045,9 +2043,9 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
20452043 // upper bound of the number of iterations), the accesses are independet, i.e.
20462044 // they are far enough appart that accesses won't access the same location
20472045 // across all loop ierations.
2048- if (HasSameSize && isSafeDependenceDistance (
2049- DL, SE, *(PSE. getSymbolicMaxBackedgeTakenCount ()),
2050- *Dist, MaxStride, TypeByteSize ))
2046+ if (HasSameSize &&
2047+ isSafeDependenceDistance (
2048+ DL, SE, *(PSE. getSymbolicMaxBackedgeTakenCount ()), *Dist, MaxStride))
20512049 return Dependence::NoDep;
20522050
20532051 const SCEVConstant *ConstDist = dyn_cast<SCEVConstant>(Dist);
@@ -2151,8 +2149,8 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
21512149
21522150 // It's not vectorizable if the distance is smaller than the minimum distance
21532151 // needed for a vectroized/unrolled version. Vectorizing one iteration in
2154- // front needs TypeByteSize * Stride . Vectorizing the last iteration needs
2155- // TypeByteSize (No need to plus the last gap distance).
2152+ // front needs CommonStride . Vectorizing the last iteration needs TypeByteSize
2153+ // (No need to plus the last gap distance).
21562154 //
21572155 // E.g. Assume one char is 1 byte in memory and one int is 4 bytes.
21582156 // foo(int *A) {
@@ -2179,8 +2177,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
21792177 // We know that Dist is positive, but it may not be constant. Use the signed
21802178 // minimum for computations below, as this ensures we compute the closest
21812179 // possible dependence distance.
2182- uint64_t MinDistanceNeeded =
2183- TypeByteSize * *CommonStride * (MinNumIter - 1 ) + TypeByteSize;
2180+ uint64_t MinDistanceNeeded = *CommonStride * (MinNumIter - 1 ) + TypeByteSize;
21842181 if (MinDistanceNeeded > static_cast <uint64_t >(MinDistance)) {
21852182 if (!ConstDist) {
21862183 // For non-constant distances, we checked the lower bound of the
@@ -2236,7 +2233,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
22362233
22372234 // An update to MinDepDistBytes requires an update to MaxSafeVectorWidthInBits
22382235 // since there is a backwards dependency.
2239- uint64_t MaxVF = MinDepDistBytes / (TypeByteSize * * CommonStride) ;
2236+ uint64_t MaxVF = MinDepDistBytes / * CommonStride;
22402237 LLVM_DEBUG (dbgs () << " LAA: Positive min distance " << MinDistance
22412238 << " with max VF = " << MaxVF << ' \n ' );
22422239
0 commit comments