Skip to content

Commit 5f62683

Browse files
Add more foldings and support in CVP
Signed-off-by: Vladimir Radosavljevic <[email protected]>
1 parent 05c9e52 commit 5f62683

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

llvm/lib/Target/EVM/EVMTargetTransformInfo.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,74 @@ static std::optional<Instruction *> foldSignExtendToConst(InstCombiner &IC,
4949

5050
static std::optional<Instruction *> instCombineSignExtend(InstCombiner &IC,
5151
IntrinsicInst &II) {
52+
constexpr unsigned BitWidth = 256;
53+
if (!II.getType()->isIntegerTy(BitWidth))
54+
return std::nullopt;
55+
5256
// Fold signextend(b, signextend(b, x)) -> signextend(b, x)
5357
Value *B = nullptr, *X = nullptr;
5458
if (match(&II, m_Intrinsic<Intrinsic::evm_signextend>(
5559
m_Value(B), m_Intrinsic<Intrinsic::evm_signextend>(
5660
m_Deferred(B), m_Value(X)))))
5761
return IC.replaceInstUsesWith(II, II.getArgOperand(1));
5862

63+
// From now on, we only handle signextend with constant byte index.
64+
const auto *ByteIdxC = dyn_cast<ConstantInt>(II.getArgOperand(0));
65+
if (!ByteIdxC)
66+
return std::nullopt;
67+
68+
// ByteIdx must be in range [0, 31].
69+
uint64_t ByteIdx = ByteIdxC->getZExtValue();
70+
if (ByteIdx >= BitWidth / 8)
71+
return std::nullopt;
72+
73+
unsigned Width = (ByteIdx + 1) * 8;
74+
75+
// Fold signextend into shifts, if second operand or the result is shift
76+
// with constant. LLVM will fuse those shifts, and will replace signextend
77+
// with a shift, which is cheaper.
78+
if (match(II.getArgOperand(1),
79+
m_OneUse(m_Shift(m_Value(), m_ConstantInt()))) ||
80+
(II.hasOneUse() &&
81+
match(II.user_back(), m_Shift(m_Specific(&II), m_ConstantInt())))) {
82+
unsigned ShiftAmt = BitWidth - Width;
83+
auto *Shl = IC.Builder.CreateShl(II.getArgOperand(1), ShiftAmt);
84+
auto *Ashr = IC.Builder.CreateAShr(Shl, ShiftAmt);
85+
return IC.replaceInstUsesWith(II, Ashr);
86+
}
87+
88+
const APInt *AndVal = nullptr;
89+
// Match signextend(b, and(x, C))
90+
if (match(II.getArgOperand(1), m_And(m_Value(X), m_APInt(AndVal)))) {
91+
APInt LowMask = APInt::getLowBitsSet(BitWidth, Width);
92+
93+
// signextend(b, x & C) -> signextend(b, x)
94+
// If and fully preservs low bits, we can drop it.
95+
if ((*AndVal & LowMask) == LowMask)
96+
return IC.replaceOperand(II, 1, X);
97+
98+
// signextend(b, x & C) -> (x & C)
99+
// If and doesn't touch upper bits, and clears sign bit, we can drop
100+
// signextend.
101+
APInt SignBit = APInt(BitWidth, 1).shl(Width - 1);
102+
if ((*AndVal & ~LowMask).isZero() && (*AndVal & SignBit).isZero())
103+
return IC.replaceInstUsesWith(II, II.getArgOperand(1));
104+
105+
// signextend(b, x & C) -> 0
106+
// If and clears all low bits, result is always 0.
107+
if ((*AndVal & LowMask).isZero())
108+
return IC.replaceInstUsesWith(II,
109+
ConstantInt::getNullValue(II.getType()));
110+
}
111+
112+
// and(signextend(b, x), C) -> and(x, C)
113+
// If and doesn't touch upper bits, we can drop signextend.
114+
if (II.hasOneUse() &&
115+
match(II.user_back(), m_And(m_Specific(&II), m_APInt(AndVal)))) {
116+
APInt LowMask = APInt::getLowBitsSet(BitWidth, Width);
117+
if ((*AndVal & ~LowMask).isZero())
118+
return IC.replaceInstUsesWith(II, II.getArgOperand(1));
119+
}
59120
return foldSignExtendToConst(IC, II);
60121
}
61122

llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/IR/Instruction.h"
3333
#include "llvm/IR/Instructions.h"
3434
#include "llvm/IR/IntrinsicInst.h"
35+
#include "llvm/IR/IntrinsicsEVM.h"
3536
#include "llvm/IR/Module.h" // EraVM local
3637
#include "llvm/IR/Operator.h"
3738
#include "llvm/IR/PatternMatch.h"
@@ -654,13 +655,51 @@ static bool processSaturatingInst(SaturatingInst *SI, LazyValueInfo *LVI) {
654655
return true;
655656
}
656657

658+
static bool processEVMSignExtend(IntrinsicInst *II, LazyValueInfo *LVI) {
659+
constexpr unsigned BitWidth = 256;
660+
if (!II->getType()->isIntegerTy(BitWidth))
661+
return false;
662+
663+
const auto *ByteIdxC = dyn_cast<ConstantInt>(II->getArgOperand(0));
664+
if (!ByteIdxC)
665+
return false;
666+
667+
// ByteIdx must be in range [0, 31].
668+
uint64_t ByteIdx = ByteIdxC->getZExtValue();
669+
if (ByteIdx >= BitWidth / 8)
670+
return false;
671+
672+
ConstantRange RRange =
673+
LVI->getConstantRangeAtUse(II->getOperandUse(1), /*UndefAllowed*/ false);
674+
if (RRange.isEmptySet())
675+
return false;
676+
677+
// Range that signextend produces is:
678+
// [ -2^(width-1), 2^(width-1)-1 ] in signed space
679+
// Since ConstantRange is [Min, Max), and Max is exclusive, we need to add 1.
680+
unsigned Width = (ByteIdx + 1) * 8;
681+
ConstantRange Range = ConstantRange::getNonEmpty(
682+
APInt::getSignedMinValue(Width).sext(BitWidth),
683+
APInt::getSignedMaxValue(Width).sext(BitWidth) + 1);
684+
685+
if (!Range.contains(RRange))
686+
return false;
687+
688+
II->replaceAllUsesWith(II->getArgOperand(1));
689+
II->eraseFromParent();
690+
return true;
691+
}
692+
657693
/// Infer nonnull attributes for the arguments at the specified callsite.
658694
static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) {
659695

660696
if (CB.getIntrinsicID() == Intrinsic::abs) {
661697
return processAbsIntrinsic(&cast<IntrinsicInst>(CB), LVI);
662698
}
663699

700+
if (CB.getIntrinsicID() == Intrinsic::evm_signextend)
701+
return processEVMSignExtend(&cast<IntrinsicInst>(CB), LVI);
702+
664703
if (auto *CI = dyn_cast<CmpIntrinsic>(&CB)) {
665704
return processCmpIntrinsic(CI, LVI);
666705
}

0 commit comments

Comments
 (0)