Skip to content

Commit 61e5d8e

Browse files
committed
Add addSafeEltsRuntimeChecks
1 parent 7ca41c8 commit 61e5d8e

File tree

3 files changed

+86
-64
lines changed

3 files changed

+86
-64
lines changed

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -569,11 +569,14 @@ addRuntimeChecks(Instruction *Loc, Loop *TheLoop,
569569
const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
570570
SCEVExpander &Expander, bool HoistRuntimeChecks = false);
571571

572-
LLVM_ABI Value *
573-
addDiffRuntimeChecks(Instruction *Loc, ArrayRef<PointerDiffInfo> Checks,
574-
SCEVExpander &Expander,
575-
function_ref<Value *(IRBuilderBase &, unsigned)> GetVF,
576-
unsigned IC, ElementCount VF, bool UseSafeEltsMask);
572+
LLVM_ABI Value *addSafeEltsRuntimeChecks(Instruction *Loc,
573+
ArrayRef<PointerDiffInfo> Checks,
574+
SCEVExpander &Expander,
575+
ElementCount VF);
576+
577+
LLVM_ABI Value *addDiffRuntimeChecks(
578+
Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander,
579+
function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC);
577580

578581
/// Struct to hold information about a partially invariant condition.
579582
struct IVConditionInfo {

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 73 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2020,10 +2020,61 @@ Value *llvm::addRuntimeChecks(
20202020
return MemoryRuntimeCheck;
20212021
}
20222022

2023+
Value *llvm::addSafeEltsRuntimeChecks(Instruction *Loc,
2024+
ArrayRef<PointerDiffInfo> Checks,
2025+
SCEVExpander &Expander, ElementCount VF) {
2026+
IRBuilder ChkBuilder(Loc->getContext(),
2027+
InstSimplifyFolder(Loc->getDataLayout()));
2028+
ChkBuilder.SetInsertPoint(Loc);
2029+
Value *MemoryRuntimeCheck = nullptr;
2030+
2031+
// Map to keep track of created compares, The key is the pair of operands for
2032+
// the compare, to allow detecting and re-using redundant compares.
2033+
DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
2034+
Value *AliasLaneMask = nullptr;
2035+
for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze,
2036+
WriteAfterRead] : Checks) {
2037+
Type *Ty = SinkStart->getType();
2038+
Value *Sink = Expander.expandCodeFor(SinkStart, Ty, Loc);
2039+
Value *Src = Expander.expandCodeFor(SrcStart, Ty, Loc);
2040+
if (SeenCompares.lookup({Sink, Src}))
2041+
continue;
2042+
2043+
unsigned IntOpc = WriteAfterRead ? Intrinsic::loop_dependence_war_mask
2044+
: Intrinsic::loop_dependence_raw_mask;
2045+
Value *SourceAsPtr = ChkBuilder.CreateCast(Instruction::IntToPtr, Src,
2046+
ChkBuilder.getPtrTy());
2047+
Value *SinkAsPtr = ChkBuilder.CreateCast(Instruction::IntToPtr, Sink,
2048+
ChkBuilder.getPtrTy());
2049+
Value *M = ChkBuilder.CreateIntrinsic(
2050+
IntOpc, {VectorType::get(ChkBuilder.getInt1Ty(), VF)},
2051+
{SourceAsPtr, SinkAsPtr, ChkBuilder.getInt64(AccessSize)}, nullptr,
2052+
"alias.lane.mask");
2053+
SeenCompares.insert({{Sink, Src}, M});
2054+
if (AliasLaneMask)
2055+
M = ChkBuilder.CreateAnd(AliasLaneMask, M);
2056+
else
2057+
AliasLaneMask = M;
2058+
}
2059+
assert(AliasLaneMask && "Expected an alias lane mask to have been created.");
2060+
auto *VecVT = VectorType::get(ChkBuilder.getInt1Ty(), VF);
2061+
// Extend to an i8 since i1 is too small to add with
2062+
Value *PopCount = ChkBuilder.CreateCast(
2063+
Instruction::ZExt, AliasLaneMask,
2064+
VectorType::get(ChkBuilder.getInt8Ty(), VecVT->getElementCount()));
2065+
2066+
PopCount =
2067+
ChkBuilder.CreateUnaryIntrinsic(Intrinsic::vector_reduce_add, PopCount);
2068+
PopCount = ChkBuilder.CreateCast(Instruction::ZExt, PopCount,
2069+
ChkBuilder.getInt64Ty());
2070+
MemoryRuntimeCheck = ChkBuilder.CreateICmpUGT(
2071+
PopCount, ConstantInt::get(ChkBuilder.getInt64Ty(), 0));
2072+
return MemoryRuntimeCheck;
2073+
}
2074+
20232075
Value *llvm::addDiffRuntimeChecks(
20242076
Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander,
2025-
function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC,
2026-
ElementCount VF, bool UseSafeEltsMask) {
2077+
function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC) {
20272078

20282079
LLVMContext &Ctx = Loc->getContext();
20292080
IRBuilder ChkBuilder(Ctx, InstSimplifyFolder(Loc->getDataLayout()));
@@ -2035,68 +2086,33 @@ Value *llvm::addDiffRuntimeChecks(
20352086
// Map to keep track of created compares, The key is the pair of operands for
20362087
// the compare, to allow detecting and re-using redundant compares.
20372088
DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
2038-
Value *AliasLaneMask = nullptr;
20392089
for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze,
20402090
WriteAfterRead] : Checks) {
20412091
Type *Ty = SinkStart->getType();
2042-
if (!VF.isScalar() && UseSafeEltsMask) {
2043-
Value *Sink = Expander.expandCodeFor(SinkStart, Ty, Loc);
2044-
Value *Src = Expander.expandCodeFor(SrcStart, Ty, Loc);
2045-
unsigned IntOpc = WriteAfterRead ? Intrinsic::loop_dependence_war_mask
2046-
: Intrinsic::loop_dependence_raw_mask;
2047-
Value *SourceAsPtr = ChkBuilder.CreateCast(Instruction::IntToPtr, Src,
2048-
ChkBuilder.getPtrTy());
2049-
Value *SinkAsPtr = ChkBuilder.CreateCast(Instruction::IntToPtr, Sink,
2050-
ChkBuilder.getPtrTy());
2051-
Value *M = ChkBuilder.CreateIntrinsic(
2052-
IntOpc, {VectorType::get(ChkBuilder.getInt1Ty(), VF)},
2053-
{SourceAsPtr, SinkAsPtr, ChkBuilder.getInt64(AccessSize)}, nullptr,
2054-
"alias.lane.mask");
2055-
if (AliasLaneMask)
2056-
M = ChkBuilder.CreateAnd(AliasLaneMask, M);
2057-
else
2058-
AliasLaneMask = M;
2059-
} else {
2060-
// Compute VF * IC * AccessSize.
2061-
auto *VFTimesICTimesSize =
2062-
ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
2063-
ConstantInt::get(Ty, IC * AccessSize));
2064-
Value *Diff =
2065-
Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart), Ty, Loc);
2066-
2067-
// Check if the same compare has already been created earlier. In that
2068-
// case, there is no need to check it again.
2069-
Value *IsConflict = SeenCompares.lookup({Diff, VFTimesICTimesSize});
2070-
if (IsConflict)
2071-
continue;
2092+
// Compute VF * IC * AccessSize.
2093+
auto *VFTimesICTimesSize =
2094+
ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
2095+
ConstantInt::get(Ty, IC * AccessSize));
2096+
Value *Diff =
2097+
Expander.expandCodeFor(SE.getMinusSCEV(SinkStart, SrcStart), Ty, Loc);
2098+
2099+
// Check if the same compare has already been created earlier. In that
2100+
// case, there is no need to check it again.
2101+
Value *IsConflict = SeenCompares.lookup({Diff, VFTimesICTimesSize});
2102+
if (IsConflict)
2103+
continue;
20722104

2105+
IsConflict =
2106+
ChkBuilder.CreateICmpULT(Diff, VFTimesICTimesSize, "diff.check");
2107+
SeenCompares.insert({{Diff, VFTimesICTimesSize}, IsConflict});
2108+
if (NeedsFreeze)
2109+
IsConflict =
2110+
ChkBuilder.CreateFreeze(IsConflict, IsConflict->getName() + ".fr");
2111+
if (MemoryRuntimeCheck) {
20732112
IsConflict =
2074-
ChkBuilder.CreateICmpULT(Diff, VFTimesICTimesSize, "diff.check");
2075-
SeenCompares.insert({{Diff, VFTimesICTimesSize}, IsConflict});
2076-
if (NeedsFreeze)
2077-
IsConflict =
2078-
ChkBuilder.CreateFreeze(IsConflict, IsConflict->getName() + ".fr");
2079-
if (MemoryRuntimeCheck) {
2080-
IsConflict =
2081-
ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
2113+
ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
20822114
}
20832115
MemoryRuntimeCheck = IsConflict;
2084-
}
2085-
}
2086-
2087-
if (AliasLaneMask) {
2088-
auto *VecVT = VectorType::get(ChkBuilder.getInt1Ty(), VF);
2089-
// Extend to an i8 since i1 is too small to add with
2090-
Value *PopCount = ChkBuilder.CreateCast(
2091-
Instruction::ZExt, AliasLaneMask,
2092-
VectorType::get(ChkBuilder.getInt8Ty(), VecVT->getElementCount()));
2093-
2094-
PopCount =
2095-
ChkBuilder.CreateUnaryIntrinsic(Intrinsic::vector_reduce_add, PopCount);
2096-
PopCount = ChkBuilder.CreateCast(Instruction::ZExt, PopCount,
2097-
ChkBuilder.getInt64Ty());
2098-
MemoryRuntimeCheck = ChkBuilder.CreateICmpUGT(
2099-
PopCount, ConstantInt::get(ChkBuilder.getInt64Ty(), 0));
21002116
}
21012117

21022118
return MemoryRuntimeCheck;

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,7 +1867,10 @@ class GeneratedRTChecks {
18671867
"vector.memcheck");
18681868

18691869
auto DiffChecks = RtPtrChecking.getDiffChecks();
1870-
if (DiffChecks) {
1870+
if (UseSafeEltsMask) {
1871+
MemRuntimeCheckCond = addSafeEltsRuntimeChecks(
1872+
MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp, VF);
1873+
} else if (DiffChecks) {
18711874
Value *RuntimeVF = nullptr;
18721875
MemRuntimeCheckCond = addDiffRuntimeChecks(
18731876
MemCheckBlock->getTerminator(), *DiffChecks, MemCheckExp,
@@ -1876,7 +1879,7 @@ class GeneratedRTChecks {
18761879
RuntimeVF = getRuntimeVF(B, B.getIntNTy(Bits), VF);
18771880
return RuntimeVF;
18781881
},
1879-
IC, VF, UseSafeEltsMask);
1882+
IC);
18801883
} else {
18811884
MemRuntimeCheckCond = addRuntimeChecks(
18821885
MemCheckBlock->getTerminator(), L, RtPtrChecking.getChecks(),

0 commit comments

Comments
 (0)