Skip to content

Commit 3e38614

Browse files
committed
[CHERIOT] Add a TTI callback to disallow LSR base register expressions.
This allows us to turn on SCEV for cheriot, with the caveat that we need to disallow LSR base register expressions that are negatively indexed, since they may not be representable on cheriot.
1 parent bfb9e86 commit 3e38614

File tree

7 files changed

+39
-6
lines changed

7 files changed

+39
-6
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,8 @@ class TargetTransformInfo {
787787
AddressingModeKind getPreferredAddressingMode(const Loop *L,
788788
ScalarEvolution *SE) const;
789789

790+
bool isLegalBaseRegForLSR(const SCEV *) const;
791+
790792
/// Return true if the target supports masked store.
791793
bool isLegalMaskedStore(Type *DataType, Align Alignment) const;
792794
/// Return true if the target supports masked load.
@@ -1996,6 +1998,7 @@ class TargetTransformInfo::Concept {
19961998
TargetLibraryInfo *LibInfo) = 0;
19971999
virtual AddressingModeKind
19982000
getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const = 0;
2001+
virtual bool isLegalBaseRegForLSR(const SCEV *) const = 0;
19992002
virtual bool isLegalMaskedStore(Type *DataType, Align Alignment) = 0;
20002003
virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment) = 0;
20012004
virtual bool isLegalNTStore(Type *DataType, Align Alignment) = 0;
@@ -2534,6 +2537,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
25342537
ScalarEvolution *SE) const override {
25352538
return Impl.getPreferredAddressingMode(L, SE);
25362539
}
2540+
bool isLegalBaseRegForLSR(const SCEV *S) const override {
2541+
return Impl.isLegalBaseRegForLSR(S);
2542+
}
25372543
bool isLegalMaskedStore(Type *DataType, Align Alignment) override {
25382544
return Impl.isLegalMaskedStore(DataType, Alignment);
25392545
}

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ class TargetTransformInfoImplBase {
276276
return TTI::AMK_None;
277277
}
278278

279+
bool isLegalBaseRegForLSR(const SCEV *S) const { return true; }
280+
279281
bool isLegalMaskedStore(Type *DataType, Align Alignment) const {
280282
return false;
281283
}

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6296,12 +6296,6 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
62966296
assert(GEP->getSourceElementType()->isSized() &&
62976297
"GEP source element type must be sized");
62986298

6299-
const DataLayout &DL = F.getParent()->getDataLayout();
6300-
// FIXME: Ideally, we should teach Scalar Evolution to
6301-
// understand fat pointers.
6302-
if (DL.isFatPointer(GEP->getPointerOperandType()->getPointerAddressSpace()))
6303-
return getUnknown(GEP);
6304-
63056299
SmallVector<const SCEV *, 4> IndexExprs;
63066300
for (Value *Index : GEP->indices())
63076301
IndexExprs.push_back(getSCEV(Index));

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,10 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L,
463463
return TTIImpl->getPreferredAddressingMode(L, SE);
464464
}
465465

466+
bool TargetTransformInfo::isLegalBaseRegForLSR(const SCEV *S) const {
467+
return TTIImpl->isLegalBaseRegForLSR(S);
468+
}
469+
466470
bool TargetTransformInfo::isLegalMaskedStore(Type *DataType,
467471
Align Alignment) const {
468472
return TTIImpl->isLegalMaskedStore(DataType, Alignment);

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "RISCVTargetTransformInfo.h"
1010
#include "MCTargetDesc/RISCVMatInt.h"
1111
#include "llvm/ADT/STLExtras.h"
12+
#include "llvm/Analysis/ScalarEvolution.h"
1213
#include "llvm/Analysis/TargetTransformInfo.h"
1314
#include "llvm/CodeGen/BasicTTIImpl.h"
1415
#include "llvm/CodeGen/CostTable.h"
@@ -2401,6 +2402,26 @@ RISCVTTIImpl::getPreferredAddressingMode(const Loop *L,
24012402
return BasicTTIImplBase::getPreferredAddressingMode(L, SE);
24022403
}
24032404

2405+
bool RISCVTTIImpl::isLegalBaseRegForLSR(const SCEV *S) const {
2406+
if (ST->hasVendorXCheriot()) {
2407+
// Disallow any add-recurrence SCEV where the base offset is negative.
2408+
// This is needed because CHERIoT can't represent pointers before the
2409+
// beginning of an array.
2410+
if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
2411+
const auto *StartAdd = dyn_cast<SCEVAddExpr>(AddRec->getStart());
2412+
if (StartAdd) {
2413+
const auto *Offset = dyn_cast<SCEVConstant>(StartAdd->getOperand(0));
2414+
if (Offset && Offset->getValue()->isNegative())
2415+
return false;
2416+
Offset = dyn_cast<SCEVConstant>(StartAdd->getOperand(1));
2417+
if (Offset && Offset->getValue()->isNegative())
2418+
return false;
2419+
}
2420+
}
2421+
}
2422+
return BasicTTIImplBase::isLegalBaseRegForLSR(S);
2423+
}
2424+
24042425
bool RISCVTTIImpl::isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
24052426
const TargetTransformInfo::LSRCost &C2) {
24062427
// RISC-V specific here are "instruction number 1st priority".

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
394394
TTI::AddressingModeKind getPreferredAddressingMode(const Loop *L,
395395
ScalarEvolution *SE) const;
396396

397+
bool isLegalBaseRegForLSR(const SCEV *S) const;
398+
397399
unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
398400
if (Vector)
399401
return RISCVRegisterClass::VRRC;

llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,6 +1516,10 @@ void Cost::RateFormula(const Formula &F,
15161516
return;
15171517
}
15181518
for (const SCEV *BaseReg : F.BaseRegs) {
1519+
if (!TTI->isLegalBaseRegForLSR(BaseReg)) {
1520+
Lose();
1521+
return;
1522+
}
15191523
if (VisitedRegs.count(BaseReg)) {
15201524
Lose();
15211525
return;

0 commit comments

Comments
 (0)