Skip to content

Commit 58134c4

Browse files
committed
Address comments
1 parent f346abe commit 58134c4

File tree

3 files changed

+78
-67
lines changed

3 files changed

+78
-67
lines changed

llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp

Lines changed: 76 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ static const unsigned UnknownAddressSpace =
113113
std::numeric_limits<unsigned>::max();
114114

115115
DEBUG_COUNTER(StraightLineStrengthReduceCounter, "slsr-counter",
116-
"Controls whether rewriteCandidateWithBasis is executed.");
116+
"Controls whether rewriteCandidate is executed.");
117117

118118
namespace {
119119

@@ -323,6 +323,24 @@ class StraightLineStrengthReduce {
323323
bool isHighEfficiency() const {
324324
return getComputationEfficiency(CandidateKind, Index, Stride, Base) >= 4;
325325
}
326+
327+
// Verify that this candidate has valid delta components relative to the
328+
// basis
329+
bool hasValidDelta(const Candidate &Basis) const {
330+
switch (DeltaKind) {
331+
case IndexDelta:
332+
// Index differs, Base and Stride must match
333+
return Base == Basis.Base && StrideSCEV == Basis.StrideSCEV;
334+
case StrideDelta:
335+
// Stride differs, Base and Index must match
336+
return Base == Basis.Base && Index == Basis.Index;
337+
case BaseDelta:
338+
// Base differs, Stride and Index must match
339+
return StrideSCEV == Basis.StrideSCEV && Index == Basis.Index;
340+
default:
341+
return false;
342+
}
343+
}
326344
};
327345

328346
bool runOnFunction(Function &F);
@@ -363,7 +381,7 @@ class StraightLineStrengthReduce {
363381
Instruction *I);
364382

365383
// Rewrites candidate C with respect to Basis.
366-
void rewriteCandidateWithBasis(const Candidate &C, const Candidate &Basis);
384+
void rewriteCandidate(const Candidate &C);
367385

368386
// Emit code that computes the "bump" from Basis to C.
369387
static Value *emitBump(const Candidate &Basis, const Candidate &C,
@@ -540,9 +558,8 @@ class StraightLineStrengthReduce {
540558
};
541559
};
542560

543-
inline llvm::raw_ostream &
544-
operator<<(llvm::raw_ostream &OS,
545-
const StraightLineStrengthReduce::Candidate &C) {
561+
inline raw_ostream &operator<<(raw_ostream &OS,
562+
const StraightLineStrengthReduce::Candidate &C) {
546563
OS << "Ins: " << *C.Ins << "\n Base: " << *C.Base
547564
<< "\n Index: " << *C.Index << "\n Stride: " << *C.Stride
548565
<< "\n StrideSCEV: " << *C.StrideSCEV;
@@ -551,10 +568,9 @@ operator<<(llvm::raw_ostream &OS,
551568
return OS;
552569
}
553570

554-
LLVM_ATTRIBUTE_UNUSED
555-
inline llvm::raw_ostream &
556-
operator<<(llvm::raw_ostream &OS,
557-
const StraightLineStrengthReduce::DeltaInfo &DI) {
571+
LLVM_DUMP_METHOD
572+
inline raw_ostream &
573+
operator<<(raw_ostream &OS, const StraightLineStrengthReduce::DeltaInfo &DI) {
558574
OS << "Cand: " << *DI.Cand << "\n";
559575
OS << "Delta Kind: ";
560576
switch (DI.DeltaKind) {
@@ -730,9 +746,14 @@ void StraightLineStrengthReduce::setBasisAndDeltaFor(Candidate &C) {
730746

731747
// If we did not find a constant delta, we might have found a variable delta
732748
if (C.Delta) {
733-
LLVM_DEBUG(dbgs() << "Found delta from ";
734-
if (C.DeltaKind == Candidate::BaseDelta) dbgs() << "Base: ";
735-
else dbgs() << "Stride: "; dbgs() << *C.Delta << "\n");
749+
LLVM_DEBUG({
750+
dbgs() << "Found delta from ";
751+
if (C.DeltaKind == Candidate::BaseDelta)
752+
dbgs() << "Base: ";
753+
else
754+
dbgs() << "Stride: ";
755+
dbgs() << *C.Delta << "\n";
756+
});
736757
assert(C.DeltaKind != Candidate::InvalidDelta && C.Basis);
737758
}
738759
}
@@ -816,8 +837,7 @@ void StraightLineStrengthReduce::sortCandidateInstructions() {
816837
// processed before processing itself.
817838
DenseMap<Instruction *, int> InDegree;
818839
for (auto &KV : DependencyGraph) {
819-
if (InDegree.find(KV.first) == InDegree.end())
820-
InDegree[KV.first] = 0;
840+
InDegree.try_emplace(KV.first, 0);
821841

822842
for (auto *Child : KV.second) {
823843
InDegree[Child]++;
@@ -839,8 +859,8 @@ void StraightLineStrengthReduce::sortCandidateInstructions() {
839859
SortedCandidateInsts.push_back(I);
840860

841861
for (auto *Next : DependencyGraph[I]) {
842-
InDegree[Next]--;
843-
if (InDegree[Next] == 0)
862+
auto &Degree = InDegree[Next];
863+
if (--Degree == 0)
844864
WorkList.push(Next);
845865
}
846866
}
@@ -1080,8 +1100,8 @@ Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis,
10801100
IRBuilder<> &Builder,
10811101
const DataLayout *DL) {
10821102
auto CreateMul = [&](Value *LHS, Value *RHS) {
1083-
if (isa<ConstantInt>(RHS)) {
1084-
APInt ConstRHS = cast<ConstantInt>(RHS)->getValue();
1103+
if (ConstantInt *CR = dyn_cast<ConstantInt>(RHS)) {
1104+
const APInt &ConstRHS = CR->getValue();
10851105
IntegerType *DeltaType =
10861106
IntegerType::get(C.Ins->getContext(), ConstRHS.getBitWidth());
10871107
if (ConstRHS.isPowerOf2()) {
@@ -1126,58 +1146,51 @@ Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis,
11261146
Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType);
11271147

11281148
return CreateMul(ExtendedStride, C.Delta);
1129-
} else {
1130-
assert(C.DeltaKind == Candidate::StrideDelta ||
1131-
C.DeltaKind == Candidate::BaseDelta);
1132-
assert(C.CandidateKind != Candidate::Mul);
1133-
// StrideDelta
1134-
// X = B + i * S
1135-
// Y = B + i * S'
1136-
// = B + i * (S + Delta)
1137-
// = B + i * S + i * Delta
1138-
// = X + i * StrideDelta
1139-
// Bump = i * (S' - S)
1140-
//
1141-
// BaseDelta
1142-
// X = B + i * S
1143-
// Y = B' + i * S
1144-
// = (B + Delta) + i * S
1145-
// = X + BaseDelta
1146-
// Bump = (B' - B).
1147-
Value *Bump = C.Delta;
1148-
if (C.DeltaKind == Candidate::StrideDelta) {
1149-
// If this value is consumed by a GEP, promote StrideDelta before doing
1150-
// StrideDelta * Index to ensure the same semantics as the original GEP.
1151-
if (C.CandidateKind == Candidate::GEP) {
1152-
auto *GEP = cast<GetElementPtrInst>(C.Ins);
1153-
Type *NewScalarIndexTy =
1154-
DL->getIndexType(GEP->getPointerOperandType()->getScalarType());
1155-
Bump = Builder.CreateSExtOrTrunc(Bump, NewScalarIndexTy);
1156-
}
1157-
if (!C.Index->isOne()) {
1158-
Value *ExtendedIndex =
1159-
Builder.CreateSExtOrTrunc(C.Index, Bump->getType());
1160-
Bump = CreateMul(Bump, ExtendedIndex);
1161-
}
1149+
}
1150+
1151+
assert(C.DeltaKind == Candidate::StrideDelta ||
1152+
C.DeltaKind == Candidate::BaseDelta);
1153+
assert(C.CandidateKind != Candidate::Mul);
1154+
// StrideDelta
1155+
// X = B + i * S
1156+
// Y = B + i * S'
1157+
// = B + i * (S + Delta)
1158+
// = B + i * S + i * Delta
1159+
// = X + i * StrideDelta
1160+
// Bump = i * (S' - S)
1161+
//
1162+
// BaseDelta
1163+
// X = B + i * S
1164+
// Y = B' + i * S
1165+
// = (B + Delta) + i * S
1166+
// = X + BaseDelta
1167+
// Bump = (B' - B).
1168+
Value *Bump = C.Delta;
1169+
if (C.DeltaKind == Candidate::StrideDelta) {
1170+
// If this value is consumed by a GEP, promote StrideDelta before doing
1171+
// StrideDelta * Index to ensure the same semantics as the original GEP.
1172+
if (C.CandidateKind == Candidate::GEP) {
1173+
auto *GEP = cast<GetElementPtrInst>(C.Ins);
1174+
Type *NewScalarIndexTy =
1175+
DL->getIndexType(GEP->getPointerOperandType()->getScalarType());
1176+
Bump = Builder.CreateSExtOrTrunc(Bump, NewScalarIndexTy);
1177+
}
1178+
if (!C.Index->isOne()) {
1179+
Value *ExtendedIndex =
1180+
Builder.CreateSExtOrTrunc(C.Index, Bump->getType());
1181+
Bump = CreateMul(Bump, ExtendedIndex);
11621182
}
1163-
return Bump;
11641183
}
1184+
return Bump;
11651185
}
11661186

1167-
void StraightLineStrengthReduce::rewriteCandidateWithBasis(
1168-
const Candidate &C, const Candidate &Basis) {
1187+
void StraightLineStrengthReduce::rewriteCandidate(const Candidate &C) {
11691188
if (!DebugCounter::shouldExecute(StraightLineStrengthReduceCounter))
11701189
return;
11711190

1172-
// If one of Base, Index, and Stride are different,
1173-
// other parts must be the same
1191+
const Candidate &Basis = *C.Basis;
11741192
assert(C.Delta && C.CandidateKind == Basis.CandidateKind &&
1175-
((C.Base == Basis.Base && C.StrideSCEV == Basis.StrideSCEV &&
1176-
C.DeltaKind == Candidate::IndexDelta) ||
1177-
(C.Base == Basis.Base && C.Index == Basis.Index &&
1178-
C.DeltaKind == Candidate::StrideDelta) ||
1179-
(C.StrideSCEV == Basis.StrideSCEV && C.Index == Basis.Index &&
1180-
C.DeltaKind == Candidate::BaseDelta)));
1193+
C.hasValidDelta(Basis));
11811194

11821195
IRBuilder<> Builder(C.Ins);
11831196
Value *Bump = emitBump(Basis, C, Builder, DL);
@@ -1258,7 +1271,7 @@ bool StraightLineStrengthReduce::runOnFunction(Function &F) {
12581271
// always before rewriting its Basis
12591272
for (Instruction *I : reverse(SortedCandidateInsts))
12601273
if (Candidate *C = pickRewriteCandidate(I))
1261-
rewriteCandidateWithBasis(*C, *C->Basis);
1274+
rewriteCandidate(*C);
12621275

12631276
for (auto *DeadIns : DeadInstructions)
12641277
// A dead instruction may be another dead instruction's op,

llvm/test/Transforms/StraightLineStrengthReduce/NVPTX/slsr-i8-gep.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
; RUN: opt < %s -passes=slsr -S | FileCheck %s
22

3-
target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
4-
target triple = "nvptx64-unknown-unknown"
3+
target triple = "nvptx64-nvidia-cuda"
54

65
; CHECK-LABEL: slsr_i8_zero_delta(
76
; CHECK-SAME: ptr [[IN:%.*]], ptr [[OUT:%.*]], i64 [[ADD:%.*]])

llvm/test/Transforms/StraightLineStrengthReduce/NVPTX/slsr-var-delta.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
; RUN: opt < %s -passes=slsr -S | FileCheck %s
22
; RUN: llc < %s -march=nvptx64 -mcpu=sm_75 | FileCheck %s --check-prefix=PTX
33

4-
target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
5-
target triple = "nvptx64-unknown-unknown"
4+
target triple = "nvptx64-nvidia-cuda"
65

76
; Test SLSR can reuse the computation by complex variable delta.
87
; The original program needs 4 mul.wide.s32, after SLSR with

0 commit comments

Comments
 (0)