diff --git a/llvm/lib/Target/EVM/EVMCodegenPrepare.cpp b/llvm/lib/Target/EVM/EVMCodegenPrepare.cpp
index 4923d58e231c..ad16f8e42ec8 100644
--- a/llvm/lib/Target/EVM/EVMCodegenPrepare.cpp
+++ b/llvm/lib/Target/EVM/EVMCodegenPrepare.cpp
@@ -20,11 +20,14 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsEVM.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Pass.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
 
 #include "EVM.h"
 
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "evm-codegen-prepare"
 
@@ -102,14 +105,93 @@ void EVMCodegenPrepare::processMemTransfer(MemTransferInst *M) {
   M->setCalledFunction(Intrinsic::getDeclaration(M->getModule(), IntrID));
 }
 
+static bool optimizeAShrInst(Instruction *I) {
+  auto *Ty = I->getType();
+  unsigned BitWidth = Ty->getIntegerBitWidth();
+  if (BitWidth != 256)
+    return false;
+
+  // Fold ashr(shl(x, c), c) -> signextend(((256 - c) / 8) - 1, x)
+  // where c is a constant divisible by 8.
+  Value *X = nullptr;
+  ConstantInt *ShiftAmt = nullptr;
+  if (match(I->getOperand(0),
+            m_OneUse(m_Shl(m_Value(X), m_ConstantInt(ShiftAmt)))) &&
+      match(I->getOperand(1), m_Specific(ShiftAmt)) &&
+      ShiftAmt->getZExtValue() % 8 == 0) {
+    IRBuilder<> Builder(I);
+    unsigned ByteIdx = ((BitWidth - ShiftAmt->getZExtValue()) / 8) - 1;
+    auto *B = ConstantInt::get(Ty, ByteIdx);
+    auto *SignExtend =
+        Builder.CreateIntrinsic(Ty, Intrinsic::evm_signextend, {B, X});
+    SignExtend->takeName(I);
+    I->replaceAllUsesWith(SignExtend);
+
+    // Erase the shl only after the ashr; in the reverse order the shl
+    // would still have a use and an assertion would be triggered.
+    auto *ToRemove = cast<Instruction>(I->getOperand(0));
+    I->eraseFromParent();
+    ToRemove->eraseFromParent();
+    return true;
+  }
+  return false;
+}
+
+static bool optimizeICmp(ICmpInst *I) {
+  auto *Ty = I->getOperand(0)->getType();
+  if (!Ty->isIntegerTy(256))
+    return false;
+
+  if (I->getPredicate() == CmpInst::ICMP_ULT) {
+    Value *X = nullptr;
+    const APInt *CAdd = nullptr, *CCmp = nullptr;
+
+    // icmp ult (add x, CAdd), CCmp -> icmp eq (evm.signextend(b, x)), x
+    // where CCmp is a power of 2 and CAdd is half of CCmp.
+    if (match(I->getOperand(0), m_OneUse(m_c_Add(m_Value(X), m_APInt(CAdd)))) &&
+        match(I->getOperand(1), m_APInt(CCmp)) && CCmp->isPowerOf2() &&
+        *CAdd == CCmp->lshr(1)) {
+      unsigned CCmpLog2 = CCmp->logBase2();
+
+      // If CCmpLog2 is not divisible by 8, we cannot use signextend.
+      if (CCmpLog2 % 8 != 0)
+        return false;
+
+      IRBuilder<> Builder(I);
+      unsigned ByteIdx = (CCmpLog2 / 8) - 1;
+
+      // ByteIdx must be in [0, 31].
+      if (ByteIdx > 31)
+        return false;
+
+      auto *B = ConstantInt::get(Ty, ByteIdx);
+      auto *SignExtend =
+          Builder.CreateIntrinsic(Ty, Intrinsic::evm_signextend, {B, X});
+      auto *NewCmp = Builder.CreateICmp(CmpInst::ICMP_EQ, SignExtend, X);
+      NewCmp->takeName(I);
+      I->replaceAllUsesWith(NewCmp);
+
+      // Erase the add only after the icmp; in the reverse order the add
+      // would still have a use and an assertion would be triggered.
+      auto *ToRemove = cast<Instruction>(I->getOperand(0));
+      I->eraseFromParent();
+      ToRemove->eraseFromParent();
+      return true;
+    }
+  }
+  return false;
+}
+
 bool EVMCodegenPrepare::runOnFunction(Function &F) {
   bool Changed = false;
   for (auto &BB : F) {
-    for (auto &I : BB) {
+    for (auto &I : make_early_inc_range(BB)) {
       if (auto *M = dyn_cast<MemTransferInst>(&I)) {
         processMemTransfer(M);
         Changed = true;
       }
+      if (I.getOpcode() == Instruction::AShr)
+        Changed |= optimizeAShrInst(&I);
+      else if (I.getOpcode() == Instruction::ICmp)
+        Changed |= optimizeICmp(cast<ICmpInst>(&I));
     }
   }
diff --git a/llvm/lib/Target/EVM/EVMTargetMachine.cpp b/llvm/lib/Target/EVM/EVMTargetMachine.cpp
index 16a512d1dcd4..c36d0e05906c 100644
--- a/llvm/lib/Target/EVM/EVMTargetMachine.cpp
+++ b/llvm/lib/Target/EVM/EVMTargetMachine.cpp
@@ -224,8 +224,8 @@ bool EVMPassConfig::addPreISel() {
 }
 
 void EVMPassConfig::addCodeGenPrepare() {
-  addPass(createEVMCodegenPreparePass());
   TargetPassConfig::addCodeGenPrepare();
+  addPass(createEVMCodegenPreparePass());
 }
 
 bool EVMPassConfig::addInstSelector() {
diff --git a/llvm/lib/Target/EVM/EVMTargetTransformInfo.cpp b/llvm/lib/Target/EVM/EVMTargetTransformInfo.cpp
index ad7c259757f8..9acb7e9c17e6 100644
--- a/llvm/lib/Target/EVM/EVMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/EVM/EVMTargetTransformInfo.cpp
@@ -21,6 +21,28 @@ using namespace llvm::PatternMatch;
 
 static std::optional<Instruction *> instCombineSignExtend(InstCombiner &IC,
                                                           IntrinsicInst &II) {
+  unsigned BitWidth = II.getType()->getIntegerBitWidth();
+  if (BitWidth != 256)
+    return std::nullopt;
+
+  // Unfold signextend(c, x) ->
+  //   ashr(shl(x, 256 - (c + 1) * 8), 256 - (c + 1) * 8)
+  // where c is a constant integer.
+  ConstantInt *C = nullptr;
+  if (match(II.getArgOperand(0), m_ConstantInt(C))) {
+    const APInt &B = C->getValue();
+
+    // A byte index of 31 or more makes the signextend a no-op,
+    // so leave it to constant folding.
+    if (B.uge(APInt(BitWidth, (BitWidth / 8) - 1)))
+      return std::nullopt;
+
+    unsigned ShiftAmt = BitWidth - ((B.getZExtValue() + 1) * 8);
+    auto *Shl = IC.Builder.CreateShl(II.getArgOperand(1), ShiftAmt);
+    auto *Ashr = IC.Builder.CreateAShr(Shl, ShiftAmt);
+    return IC.replaceInstUsesWith(II, Ashr);
+  }
+
   // Fold signextend(b, signextend(b, x)) -> signextend(b, x)
   Value *B = nullptr, *X = nullptr;
   if (match(&II, m_Intrinsic<Intrinsic::evm_signextend>(
diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index b3d439a6f113..1a01c662760e 100644
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -752,6 +752,8 @@ static bool narrowSDivOrSRem(BinaryOperator *Instr, const ConstantRange &LCR,
   assert(Instr->getOpcode() == Instruction::SDiv ||
          Instr->getOpcode() == Instruction::SRem);
 
+  return false;
+
   // Find the smallest power of two bitwidth that's sufficient to hold Instr's
   // operands.
   unsigned OrigWidth = Instr->getType()->getScalarSizeInBits();
@@ -885,6 +887,7 @@ static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR,
   assert(Instr->getOpcode() == Instruction::UDiv ||
         Instr->getOpcode() == Instruction::URem);
 
+  return false;
   // Find the smallest power of two bitwidth that's sufficient to hold Instr's
   // operands.
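
Reviewer note: a minimal, hand-written LLVM IR sketch of the two EVMCodegenPrepare folds above; the function names are illustrative and the expected outputs are shown in comments, not generated by the pass.

define i256 @fold_ashr_shl(i256 %x) {
  ; ByteIdx = ((256 - 128) / 8) - 1 = 15, so the expected rewrite is
  ;   %ashr = call i256 @llvm.evm.signextend(i256 15, i256 %x)
  %shl = shl i256 %x, 128
  %ashr = ashr i256 %shl, 128
  ret i256 %ashr
}

define i1 @fold_icmp_ult(i256 %x) {
  ; CAdd = 2^127 is half of CCmp = 2^128 (a power of two), CCmpLog2 = 128,
  ; ByteIdx = 128 / 8 - 1 = 15, so the expected rewrite is
  ;   %se = call i256 @llvm.evm.signextend(i256 15, i256 %x)
  ;   %cmp = icmp eq i256 %se, %x
  %add = add i256 %x, 170141183460469231731687303715884105728
  %cmp = icmp ult i256 %add, 340282366920938463463374607431768211456
  ret i1 %cmp
}
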
diff --git a/llvm/test/CodeGen/EVM/O3-pipeline.ll b/llvm/test/CodeGen/EVM/O3-pipeline.ll
index 5185960da01a..2cec9ff98f34 100644
--- a/llvm/test/CodeGen/EVM/O3-pipeline.ll
+++ b/llvm/test/CodeGen/EVM/O3-pipeline.ll
@@ -55,10 +55,8 @@ target triple = "evm"
 ; CHECK-NEXT:   Expand reduction intrinsics
 ; CHECK-NEXT:   Natural Loop Information
 ; CHECK-NEXT:   TLS Variable Hoist
-; CHECK-NEXT:   Final transformations before code generation
-; CHECK-NEXT:     Dominator Tree Construction
-; CHECK-NEXT:     Natural Loop Information
 ; CHECK-NEXT:   CodeGen Prepare
+; CHECK-NEXT:   Final transformations before code generation
 ; CHECK-NEXT: Lower invoke and unwind, for unwindless code generators
 ; CHECK-NEXT: Remove unreachable blocks from the CFG
 ; CHECK-NEXT: CallGraph Construction
diff --git a/llvm/test/CodeGen/EVM/fold-signextend.ll b/llvm/test/CodeGen/EVM/fold-signextend.ll
index 559d46381321..015fe1deb1a1 100644
--- a/llvm/test/CodeGen/EVM/fold-signextend.ll
+++ b/llvm/test/CodeGen/EVM/fold-signextend.ll
@@ -7,7 +7,8 @@ target triple = "evm"
 define i256 @test_const(i256 %x) {
 ; CHECK-LABEL: define i256 @test_const(
 ; CHECK-SAME: i256 [[X:%.*]]) {
-; CHECK-NEXT:    [[SIGNEXT1:%.*]] = call i256 @llvm.evm.signextend(i256 15, i256 [[X]])
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i256 [[X]], 128
+; CHECK-NEXT:    [[SIGNEXT1:%.*]] = ashr exact i256 [[TMP1]], 128
 ; CHECK-NEXT:    ret i256 [[SIGNEXT1]]
 ;
   %signext1 = call i256 @llvm.evm.signextend(i256 15, i256 %x)
@@ -18,8 +19,8 @@ define i256 @test_const(i256 %x) {
 define i256 @test_const_ne(i256 %x) {
 ; CHECK-LABEL: define i256 @test_const_ne(
 ; CHECK-SAME: i256 [[X:%.*]]) {
-; CHECK-NEXT:    [[SIGNEXT1:%.*]] = call i256 @llvm.evm.signextend(i256 15, i256 [[X]])
-; CHECK-NEXT:    [[SIGNEXT2:%.*]] = call i256 @llvm.evm.signextend(i256 10, i256 [[SIGNEXT1]])
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i256 [[X]], 168
+; CHECK-NEXT:    [[SIGNEXT2:%.*]] = ashr exact i256 [[TMP1]], 168
 ; CHECK-NEXT:    ret i256 [[SIGNEXT2]]
 ;
   %signext1 = call i256 @llvm.evm.signextend(i256 15, i256 %x)
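
Reviewer note: the shift amounts in the updated CHECK lines follow from ShiftAmt = 256 - (c + 1) * 8: for c = 15 (test_const) that is 256 - 128 = 128, and for c = 10 (the outer signextend in test_const_ne) it is 256 - 88 = 168. A hand-written sketch of the c = 10 unfold, not part of the test file:

define i256 @unfold_c10(i256 %x) {
  ; Byte index 10 keeps bits 0..87; shifting left by 168 moves bit 87
  ; (the sign bit of an 11-byte value) to bit 255, and the arithmetic
  ; shift right replicates it back down through the upper bits.
  %shl = shl i256 %x, 168
  %sext = ashr i256 %shl, 168
  ret i256 %sext
}
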