2 changes: 2 additions & 0 deletions llvm/include/llvm/IR/ConstantRange.h
@@ -554,6 +554,8 @@ class [[nodiscard]] ConstantRange {
/// Calculate ctpop range.
ConstantRange ctpop() const;

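/// Calculate the range of the EVM SIGNEXTEND operation, where this range
/// is the byte-index operand and \p Other is the value operand.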
ConstantRange evmSignExtend(const ConstantRange &Other) const;

/// Represents whether an operation on the given constant range is known to
/// always or never overflow.
enum class OverflowResult {
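For context: EVM's SIGNEXTEND(b, x) takes the low (b + 1) * 8 bits of x and sign-extends them to the full 256-bit word, and is the identity for b >= 31. A minimal APInt model of the semantics this patch assumes (the helper name is illustrative, not part of the change):

#include "llvm/ADT/APInt.h"
#include <cassert>
using namespace llvm;

// Reference model of EVM SIGNEXTEND on 256-bit words (illustrative only):
// keep the low (ByteIdx + 1) * 8 bits and copy their sign bit upward.
static APInt evmSignExtendModel(uint64_t ByteIdx, const APInt &X) {
  assert(X.getBitWidth() == 256 && "EVM words are 256 bits wide");
  if (ByteIdx >= 31)
    return X; // Extending from the top byte (or beyond) is a no-op.
  unsigned Width = unsigned(ByteIdx + 1) * 8;
  return X.trunc(Width).sext(256);
}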
4 changes: 2 additions & 2 deletions llvm/include/llvm/IR/IntrinsicsEVM.td
@@ -91,8 +91,8 @@ def int_evm_sha3
NoCapture<ArgIndex<0>>, IntrWillReturn]>;

def int_evm_signextend
-    : Intrinsic<[llvm_i256_ty], [llvm_i256_ty, llvm_i256_ty],
-                [IntrNoMem, IntrWillReturn]>;
+    : DefaultAttrsIntrinsic<[llvm_i256_ty], [llvm_i256_ty, llvm_i256_ty],
+                            [IntrNoMem, IntrSpeculatable]>;

def int_evm_byte : Intrinsic<[llvm_i256_ty], [llvm_i256_ty, llvm_i256_ty],
[IntrNoMem, IntrWillReturn]>;
76 changes: 76 additions & 0 deletions llvm/lib/Analysis/ValueTracking.cpp
@@ -57,6 +57,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsEVM.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/LLVMContext.h"
@@ -1629,6 +1630,30 @@ static void computeKnownBitsFromOperator(const Operator *I,
switch (II->getIntrinsicID()) {
default:
break;
case Intrinsic::evm_signextend: {
auto *Ty = dyn_cast<IntegerType>(II->getType());
if (!Ty)
break;

unsigned BitWidth = Ty->getIntegerBitWidth();
if (BitWidth != 256)
break;

const auto *ByteIdxC = dyn_cast<ConstantInt>(II->getArgOperand(0));
if (!ByteIdxC)
break;

// ByteIdx must be in range [0, 31]. Compare as an APInt first, since
// getZExtValue() asserts on constants wider than 64 bits.
if (ByteIdxC->getValue().uge(BitWidth / 8))
break;
uint64_t ByteIdx = ByteIdxC->getZExtValue();

computeKnownBits(I->getOperand(1), DemandedElts, Known2, Depth + 1, Q);
unsigned Width = (ByteIdx + 1) * 8;
Known = Known2.trunc(Width).sext(BitWidth);
break;
}

case Intrinsic::abs: {
computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
bool IntMinIsPoison = match(II->getArgOperand(1), m_One());
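As a sanity check on the transfer function above: for ByteIdx == 0 only the low byte of the operand survives, and every higher result bit becomes a copy of bit 7. A unit-test-style sketch of the trunc/sext step (a standalone illustration, not patch code):

#include "llvm/Support/KnownBits.h"
#include <cassert>
using namespace llvm;

// If the low byte of x is known to be exactly 0x80, evm_signextend(0, x)
// is fully known: every bit from 7 upward is a copy of the sign bit, so
// the result is 0xFF...F80, i.e. -128.
static void knownBitsExample() {
  KnownBits Arg(256);
  Arg.One = APInt(256, 0x80);                        // bit 7 known one
  Arg.Zero = APInt::getLowBitsSet(256, 8) ^ Arg.One; // bits 0..6 known zero
  KnownBits Res = Arg.trunc(8).sext(256);
  assert(Res.isConstant() && Res.getConstant().getSExtValue() == -128);
}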
@@ -3184,6 +3209,8 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
return true;
case Intrinsic::experimental_get_vector_length:
return isKnownNonZero(I->getOperand(0), Q, Depth);
case Intrinsic::evm_signextend:
// SIGNEXTEND only preserves the low (ByteIdx + 1) * 8 bits of its value
// operand, so a non-zero value does not imply a non-zero result. Forward
// the query only when the extension is a no-op (ByteIdx >= 31).
if (const auto *ByteIdxC = dyn_cast<ConstantInt>(II->getArgOperand(0)))
if (ByteIdxC->getValue().uge(31))
return isKnownNonZero(II->getArgOperand(1), DemandedElts, Q, Depth);
break;
default:
break;
}
@@ -3740,6 +3767,28 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V,
return MinSignBits;
}

static unsigned computeNumSignBitsForEVMSignExtend(const IntrinsicInst *II) {
auto *Ty = dyn_cast<IntegerType>(II->getType());
if (!Ty)
return 1;

unsigned BitWidth = Ty->getIntegerBitWidth();
if (BitWidth != 256)
return 1;

const auto *ByteIdxC = dyn_cast<ConstantInt>(II->getArgOperand(0));
if (!ByteIdxC)
return 1;

// ByteIdx must be in range [0, 31]; compare as an APInt, since
// getZExtValue() asserts on constants wider than 64 bits.
if (ByteIdxC->getValue().uge(BitWidth / 8))
return 1;
uint64_t ByteIdx = ByteIdxC->getZExtValue();

unsigned Width = (ByteIdx + 1) * 8;
return BitWidth - Width + 1;
}

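The count follows from the trunc/sext shape of the result: a Width-bit value sign-extended to 256 bits has at least 256 - Width + 1 identical top bits. Two boundary cases as a worked check:

// ByteIdx == 0  -> Width == 8   -> at least 256 - 8 + 1 == 249 sign bits
//                  (e.g. 0x80 becomes 0xFF...F80: 248 copies plus bit 7).
// ByteIdx == 31 -> Width == 256 -> 256 - 256 + 1 == 1, i.e. nothing known.
static_assert(256 - 8 + 1 == 249 && 256 - 256 + 1 == 1);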
static unsigned ComputeNumSignBitsImpl(const Value *V,
const APInt &DemandedElts,
unsigned Depth, const SimplifyQuery &Q);
@@ -4070,6 +4119,8 @@ static unsigned ComputeNumSignBitsImpl(const Value *V,
switch (II->getIntrinsicID()) {
default:
break;
case Intrinsic::evm_signextend:
return computeNumSignBitsForEVMSignExtend(II);
case Intrinsic::abs:
Tmp =
ComputeNumSignBits(U->getOperand(0), DemandedElts, Depth + 1, Q);
@@ -9509,10 +9560,35 @@ static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower,
}
}

static ConstantRange getRangeForEVMSignExtend(const IntrinsicInst &II) {
unsigned BitWidth = II.getType()->getIntegerBitWidth();
if (BitWidth != 256)
return ConstantRange::getFull(BitWidth);

auto *ByteIdxC = dyn_cast<ConstantInt>(II.getArgOperand(0));
if (!ByteIdxC)
return ConstantRange::getFull(BitWidth);

// ByteIdx must be in range [0, 31]; compare as an APInt, since
// getZExtValue() asserts on constants wider than 64 bits.
if (ByteIdxC->getValue().uge(BitWidth / 8))
return ConstantRange::getFull(BitWidth);
uint64_t ByteIdx = ByteIdxC->getZExtValue();

// The range signextend produces is [-2^(Width-1), 2^(Width-1) - 1] in
// signed space. ConstantRange upper bounds are exclusive, so add 1 to the
// signed maximum.
unsigned Width = (ByteIdx + 1) * 8;
return ConstantRange::getNonEmpty(
APInt::getSignedMinValue(Width).sext(BitWidth),
APInt::getSignedMaxValue(Width).sext(BitWidth) + 1);
}

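For example, ByteIdx == 0 yields the signed 8-bit range embedded in i256. A small check of the bounds computed above (illustrative, not patch code):

#include "llvm/IR/ConstantRange.h"
#include <cassert>
using namespace llvm;

// getRangeForEVMSignExtend with ByteIdx == 0: the signed range [-128, 127],
// written with an exclusive upper bound.
static void rangeExample() {
  ConstantRange CR = ConstantRange::getNonEmpty(
      APInt::getSignedMinValue(8).sext(256),      // -128
      APInt::getSignedMaxValue(8).sext(256) + 1); // 127 + 1 (exclusive)
  assert(CR.getSignedMin().getSExtValue() == -128 &&
         CR.getSignedMax().getSExtValue() == 127);
}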
static ConstantRange getRangeForIntrinsic(const IntrinsicInst &II) {
unsigned Width = II.getType()->getScalarSizeInBits();
const APInt *C;
switch (II.getIntrinsicID()) {
case Intrinsic::evm_signextend:
return getRangeForEVMSignExtend(II);
case Intrinsic::ctpop:
case Intrinsic::ctlz:
case Intrinsic::cttz:
24 changes: 24 additions & 0 deletions llvm/lib/IR/ConstantRange.cpp
@@ -27,6 +27,7 @@
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsEVM.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/Support/Compiler.h"
@@ -1009,6 +1010,7 @@ bool ConstantRange::isIntrinsicSupported(Intrinsic::ID IntrinsicID) {
case Intrinsic::ctlz:
case Intrinsic::cttz:
case Intrinsic::ctpop:
case Intrinsic::evm_signextend:
return true;
default:
return false;
@@ -1018,6 +1020,8 @@ bool ConstantRange::isIntrinsicSupported(Intrinsic::ID IntrinsicID) {
ConstantRange ConstantRange::intrinsic(Intrinsic::ID IntrinsicID,
ArrayRef<ConstantRange> Ops) {
switch (IntrinsicID) {
case Intrinsic::evm_signextend:
return Ops[0].evmSignExtend(Ops[1]);
case Intrinsic::uadd_sat:
return Ops[0].uadd_sat(Ops[1]);
case Intrinsic::usub_sat:
@@ -1953,6 +1957,26 @@ ConstantRange ConstantRange::ctpop() const {
return CR1.unionWith(CR2);
}

ConstantRange ConstantRange::evmSignExtend(const ConstantRange &Other) const {
unsigned BitWidth = getBitWidth();
if (BitWidth != 256)
return getFull();

if (isEmptySet() || Other.isEmptySet())
return getEmpty();

if (!isSingleElement())
return getFull();

// ByteIdx must be in range [0, 31]; compare as an APInt, since
// getZExtValue() asserts on values wider than 64 bits.
const APInt &ByteIdxVal = *getSingleElement();
if (ByteIdxVal.uge(BitWidth / 8))
return getFull();

unsigned Width = (ByteIdxVal.getZExtValue() + 1) * 8;
return Other.truncate(Width).signExtend(BitWidth);
}

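With evm_signextend registered in isIntrinsicSupported and ConstantRange::intrinsic, range-based passes that fold intrinsics over ranges (e.g. LVI/CVP) pick this up automatically. A hedged usage sketch of the generic entry point:

#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/IntrinsicsEVM.h"
#include <cassert>
using namespace llvm;

// A constant byte index of 0 and a fully unknown value give the 8-bit
// signed range, sign-extended to i256.
static void intrinsicRangeExample() {
  ConstantRange ByteIdx(APInt(256, 0));
  ConstantRange Val = ConstantRange::getFull(256);
  ConstantRange R =
      ConstantRange::intrinsic(Intrinsic::evm_signextend, {ByteIdx, Val});
  assert(R.getSignedMin().getSExtValue() == -128 &&
         R.getSignedMax().getSExtValue() == 127);
}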
ConstantRange::OverflowResult ConstantRange::unsignedAddMayOverflow(
const ConstantRange &Other) const {
if (isEmptySet() || Other.isEmptySet())
50 changes: 49 additions & 1 deletion llvm/lib/Target/EVM/EVMCodegenPrepare.cpp
@@ -20,11 +20,13 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsEVM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"

#include "EVM.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "evm-codegen-prepare"

@@ -102,14 +104,60 @@ void EVMCodegenPrepare::processMemTransfer(MemTransferInst *M) {
M->setCalledFunction(Intrinsic::getDeclaration(M->getModule(), IntrID));
}

static bool optimizeICmp(ICmpInst *I) {
auto *Ty = I->getOperand(0)->getType();
if (!Ty->isIntegerTy(256))
return false;

if (I->getPredicate() == CmpInst::ICMP_ULT) {
Value *X = nullptr;
const APInt *CAdd = nullptr, *CCmp = nullptr;

// icmp ult (add x, CAdd), CCmp -> icmp eq (evm.signextend(b, x)), x
// where CCmp is a power of two and CAdd == CCmp / 2.
if (match(I->getOperand(0), m_OneUse(m_c_Add(m_Value(X), m_APInt(CAdd)))) &&
match(I->getOperand(1), m_APInt(CCmp)) && CCmp->isPowerOf2() &&
*CAdd == CCmp->lshr(1)) {
unsigned CCmpLog2 = CCmp->logBase2();

// If CCmpLog2 is not a multiple of 8, byte-granular signextend cannot be
// used.
if (CCmpLog2 % 8 != 0)
return false;

IRBuilder<> Builder(I);
unsigned ByteIdx = (CCmpLog2 / 8) - 1;

// ByteIdx should be in [0, 31].
if (ByteIdx > 31)
return false;

auto *B = ConstantInt::get(Ty, ByteIdx);
auto *SignExtend =
Builder.CreateIntrinsic(Ty, Intrinsic::evm_signextend, {B, X});
auto *NewCmp = Builder.CreateICmp(CmpInst::ICMP_EQ, SignExtend, X);
NewCmp->takeName(I);
I->replaceAllUsesWith(NewCmp);

// Erase the icmp before the add: the add still has a use until the icmp
// is gone, so erasing it first would trigger an assertion.
auto *ToRemove = cast<Instruction>(I->getOperand(0));
I->eraseFromParent();
ToRemove->eraseFromParent();
return true;
}
}
return false;
}

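The rewrite above relies on a standard identity: (x + 2^(W-1)) u< 2^W holds exactly when x fits in W signed bits, i.e. exactly when sign-extending x from its low W bits is a no-op. A scaled-down, self-contained check at W == 8 on 64-bit words (illustrative, not patch code):

#include <cassert>
#include <cstdint>

int main() {
  // (x + 128) u< 256  <=>  signextend-from-byte-0(x) == x, on 64-bit words.
  for (uint64_t X : {0ULL, 1ULL, 127ULL, 128ULL, ~0ULL, ~127ULL, ~128ULL}) {
    bool ViaAdd = (X + 128) < 256; // unsigned add wraps; compare is unsigned
    bool ViaSext = static_cast<uint64_t>(static_cast<int64_t>(
                       static_cast<int8_t>(X))) == X;
    assert(ViaAdd == ViaSext);
  }
  return 0;
}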
bool EVMCodegenPrepare::runOnFunction(Function &F) {
bool Changed = false;
for (auto &BB : F) {
-    for (auto &I : BB) {
+    for (auto &I : make_early_inc_range(BB)) {
if (auto *M = dyn_cast<MemTransferInst>(&I)) {
processMemTransfer(M);
Changed = true;
}
if (auto *Cmp = dyn_cast<ICmpInst>(&I))
Changed |= optimizeICmp(Cmp);
}
}

91 changes: 90 additions & 1 deletion llvm/lib/Target/EVM/EVMTargetTransformInfo.cpp
@@ -19,16 +19,105 @@ using namespace llvm::PatternMatch;

#define DEBUG_TYPE "evmtti"

static std::optional<Instruction *> foldSignExtendToConst(InstCombiner &IC,
IntrinsicInst &II) {
constexpr unsigned BitWidth = 256;
if (!II.getType()->isIntegerTy(BitWidth))
return std::nullopt;

const auto *ByteIdxC = dyn_cast<ConstantInt>(II.getArgOperand(0));
if (!ByteIdxC)
return std::nullopt;

// ByteIdx must be in range [0, 31]; compare as an APInt, since
// getZExtValue() asserts on constants wider than 64 bits.
if (ByteIdxC->getValue().uge(BitWidth / 8))
return std::nullopt;
uint64_t ByteIdx = ByteIdxC->getZExtValue();

// Compute known bits of the input.
KnownBits Known(BitWidth);
IC.computeKnownBits(II.getArgOperand(1), Known, 0, &II);

unsigned Width = (ByteIdx + 1) * 8;
APInt LowMask = APInt::getLowBitsSet(BitWidth, Width);
if (((Known.Zero | Known.One) & LowMask) == LowMask) {
APInt Folded = (Known.One & LowMask).trunc(Width).sext(BitWidth);
return IC.replaceInstUsesWith(II, ConstantInt::get(II.getType(), Folded));
}
return std::nullopt;
}

static std::optional<Instruction *> instCombineSignExtend(InstCombiner &IC,
IntrinsicInst &II) {
constexpr unsigned BitWidth = 256;
if (!II.getType()->isIntegerTy(BitWidth))
return std::nullopt;

// Fold signextend(b, signextend(b, x)) -> signextend(b, x)
Value *B = nullptr, *X = nullptr;
if (match(&II, m_Intrinsic<Intrinsic::evm_signextend>(
m_Value(B), m_Intrinsic<Intrinsic::evm_signextend>(
m_Deferred(B), m_Value(X)))))
return IC.replaceInstUsesWith(II, II.getArgOperand(1));

-  return std::nullopt;
// From now on, we only handle signextend with constant byte index.
const auto *ByteIdxC = dyn_cast<ConstantInt>(II.getArgOperand(0));
if (!ByteIdxC)
return std::nullopt;

// ByteIdx must be in range [0, 31]; compare as an APInt, since
// getZExtValue() asserts on constants wider than 64 bits.
if (ByteIdxC->getValue().uge(BitWidth / 8))
return std::nullopt;
uint64_t ByteIdx = ByteIdxC->getZExtValue();

unsigned Width = (ByteIdx + 1) * 8;

// Fold signextend into shifts if its value operand or its result is a
// shift by a constant. LLVM will fuse the resulting shifts and replace
// the signextend with a shift, which is cheaper.
if (match(II.getArgOperand(1),
m_OneUse(m_Shift(m_Value(), m_ConstantInt()))) ||
(II.hasOneUse() &&
match(II.user_back(), m_Shift(m_Specific(&II), m_ConstantInt())))) {
unsigned ShiftAmt = BitWidth - Width;
auto *Shl = IC.Builder.CreateShl(II.getArgOperand(1), ShiftAmt);
auto *Ashr = IC.Builder.CreateAShr(Shl, ShiftAmt);
return IC.replaceInstUsesWith(II, Ashr);
}

const APInt *AndVal = nullptr;
// Match signextend(b, and(x, C))
if (match(II.getArgOperand(1), m_And(m_Value(X), m_APInt(AndVal)))) {
APInt LowMask = APInt::getLowBitsSet(BitWidth, Width);

// signextend(b, x & C) -> signextend(b, x)
// If the and fully preserves the low bits, we can drop it.
if ((*AndVal & LowMask) == LowMask)
return IC.replaceOperand(II, 1, X);

// signextend(b, x & C) -> (x & C)
// If the and does not touch the upper bits and clears the sign bit, we
// can drop the signextend.
APInt SignBit = APInt(BitWidth, 1).shl(Width - 1);
if ((*AndVal & ~LowMask).isZero() && (*AndVal & SignBit).isZero())
return IC.replaceInstUsesWith(II, II.getArgOperand(1));

// signextend(b, x & C) -> 0
// If the and clears all the low bits, the result is always 0.
if ((*AndVal & LowMask).isZero())
return IC.replaceInstUsesWith(II,
ConstantInt::getNullValue(II.getType()));
}

// and(signextend(b, x), C) -> and(x, C)
// If the and does not touch the upper bits, we can drop the signextend.
if (II.hasOneUse() &&
match(II.user_back(), m_And(m_Specific(&II), m_APInt(AndVal)))) {
APInt LowMask = APInt::getLowBitsSet(BitWidth, Width);
if ((*AndVal & ~LowMask).isZero())
return IC.replaceInstUsesWith(II, II.getArgOperand(1));
}
return foldSignExtendToConst(IC, II);
}
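The shift fold in instCombineSignExtend uses the classic encoding of a narrow sign extension on a wide word: shift left until the narrow sign bit reaches the top, then arithmetic-shift back down. A scaled-down 64-bit check for ByteIdx == 0 (illustrative, not patch code):

#include <cassert>
#include <cstdint>

int main() {
  // signextend(b, x) == ashr(shl(x, BW - W), BW - W) with W = (b + 1) * 8.
  uint64_t X = 0x123456789abcde80ULL; // low byte 0x80, narrow sign bit set
  uint64_t ViaShifts =
      static_cast<uint64_t>(static_cast<int64_t>(X << 56) >> 56);
  assert(ViaShifts == 0xffffffffffffff80ULL); // low byte kept, sign copied up
  return 0;
}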

std::optional<Instruction *>