Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions llvm/include/llvm/Analysis/LoopAccessAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,17 +365,25 @@ class MemoryDepChecker {
void mergeInStatus(VectorizationSafetyStatus S);

struct DepDistanceStrideAndSizeInfo {
// Strides could either be scaled (in bytes, taking the size of the
// underlying type into account), or unscaled (in indexing units; unscaled
// stride = scaled stride / size of underlying type). Here, strides are
// unscaled.
const SCEV *Dist;
uint64_t StrideA;
uint64_t StrideB;
uint64_t MaxStride;
std::optional<uint64_t> CommonStride;
bool ShouldRetryWithRuntimeCheck;
uint64_t TypeByteSize;
bool AIsWrite;
bool BIsWrite;

DepDistanceStrideAndSizeInfo(const SCEV *Dist, uint64_t StrideA,
uint64_t StrideB, uint64_t TypeByteSize,
bool AIsWrite, bool BIsWrite)
: Dist(Dist), StrideA(StrideA), StrideB(StrideB),
DepDistanceStrideAndSizeInfo(const SCEV *Dist, uint64_t MaxStride,
std::optional<uint64_t> CommonStride,
bool ShouldRetryWithRuntimeCheck,
uint64_t TypeByteSize, bool AIsWrite,
bool BIsWrite)
: Dist(Dist), MaxStride(MaxStride), CommonStride(CommonStride),
ShouldRetryWithRuntimeCheck(ShouldRetryWithRuntimeCheck),
TypeByteSize(TypeByteSize), AIsWrite(AIsWrite), BIsWrite(BIsWrite) {}
};

Expand Down
37 changes: 25 additions & 12 deletions llvm/lib/Analysis/LoopAccessAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1994,8 +1994,23 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
DL.getTypeStoreSizeInBits(ATy) == DL.getTypeStoreSizeInBits(BTy);
if (!HasSameSize)
TypeByteSize = 0;
return DepDistanceStrideAndSizeInfo(Dist, std::abs(StrideAPtrInt),
std::abs(StrideBPtrInt), TypeByteSize,

StrideAPtrInt = std::abs(StrideAPtrInt);
StrideBPtrInt = std::abs(StrideBPtrInt);

uint64_t MaxStride = std::max(StrideAPtrInt, StrideBPtrInt);

std::optional<uint64_t> CommonStride;
if (StrideAPtrInt == StrideBPtrInt)
CommonStride = StrideAPtrInt;

// TODO: Historically, we don't retry with runtime checks unless the
// (unscaled) strides are the same. Fix this once the condition for runtime
// checks in isDependent is fixed.
bool ShouldRetryWithRuntimeCheck = CommonStride.has_value();

return DepDistanceStrideAndSizeInfo(Dist, MaxStride, CommonStride,
ShouldRetryWithRuntimeCheck, TypeByteSize,
AIsWrite, BIsWrite);
}

Expand All @@ -2011,23 +2026,21 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
if (std::holds_alternative<Dependence::DepType>(Res))
return std::get<Dependence::DepType>(Res);

auto &[Dist, StrideA, StrideB, TypeByteSize, AIsWrite, BIsWrite] =
auto &[Dist, MaxStride, CommonStride, ShouldRetryWithRuntimeCheck,
TypeByteSize, AIsWrite, BIsWrite] =
std::get<DepDistanceStrideAndSizeInfo>(Res);
bool HasSameSize = TypeByteSize > 0;

std::optional<uint64_t> CommonStride =
StrideA == StrideB ? std::make_optional(StrideA) : std::nullopt;
if (isa<SCEVCouldNotCompute>(Dist)) {
// TODO: Relax requirement that there is a common stride to retry with
// non-constant distance dependencies.
FoundNonConstantDistanceDependence |= CommonStride.has_value();
// TODO: Relax requirement that there is a common unscaled stride to retry
// with non-constant distance dependencies.
FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
LLVM_DEBUG(dbgs() << "LAA: Dependence because of uncomputable distance.\n");
return Dependence::Unknown;
}

ScalarEvolution &SE = *PSE.getSE();
auto &DL = InnermostLoop->getHeader()->getDataLayout();
uint64_t MaxStride = std::max(StrideA, StrideB);

// If the distance between the accesses is larger than their maximum absolute
// stride multiplied by the symbolic maximum backedge taken count (which is an
Expand Down Expand Up @@ -2086,7 +2099,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
// condition to consider retrying with runtime checks. Historically, we
// did not set it when strides were different but there is no inherent
// reason to.
FoundNonConstantDistanceDependence |= CommonStride.has_value();
FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
return Dependence::Unknown;
}
if (!HasSameSize ||
Expand All @@ -2105,7 +2118,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
int64_t MinDistance = SE.getSignedRangeMin(Dist).getSExtValue();
// Below we only handle strictly positive distances.
if (MinDistance <= 0) {
FoundNonConstantDistanceDependence |= CommonStride.has_value();
FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
return Dependence::Unknown;
}

Expand All @@ -2118,7 +2131,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
// condition to consider retrying with runtime checks. Historically, we
// did not set it when strides were different but there is no inherent
// reason to.
FoundNonConstantDistanceDependence |= CommonStride.has_value();
FoundNonConstantDistanceDependence |= ShouldRetryWithRuntimeCheck;
}

if (!HasSameSize) {
Expand Down
Loading