2929#include "llvm/CodeGen/MachineInstrBuilder.h"
3030#include "llvm/CodeGen/MachineJumpTableInfo.h"
3131#include "llvm/CodeGen/MachineRegisterInfo.h"
32+ #include "llvm/CodeGen/SDPatternMatch.h"
3233#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
3334#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
3435#include "llvm/CodeGen/ValueTypes.h"
5051#include <optional>
5152
5253using namespace llvm;
54+ using namespace llvm::PatternMatch;
5355
5456#define DEBUG_TYPE "riscv-lower"
5557
@@ -14322,7 +14324,6 @@ static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, int64_t &ShlConst
1432214324 auto *AddConstNode = dyn_cast<ConstantSDNode>(AddI->getOperand(1));
1432314325 if (!AddConstNode)
1432414326 return false;
14325- AddConst = AddConstNode->getSExtValue();
1432614327
1432714328 SDValue SHLVal = AddI->getOperand(0);
1432814329 if (SHLVal->getOpcode() != ISD::SHL)
@@ -14332,10 +14333,11 @@ static bool checkAddiForShift(SDValue AddI, int64_t &AddConst, int64_t &ShlConst
1433214333 if (!ShiftNode)
1433314334 return false;
1433414335
14335- ShlConst = ShiftNode->getSExtValue();
1433614336 if (ShlConst < 1 || ShlConst > 3)
1433714337 return false;
1433814338
14339+ AddConst = AddConstNode->getSExtValue();
14340+ ShlConst = ShiftNode->getSExtValue();
1433914341 return true;
1434014342}
1434114343
@@ -14349,7 +14351,7 @@ static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG,
1434914351
1435014352 // Skip for vector types and larger types.
1435114353 EVT VT = N->getValueType(0);
14352- if (VT.isVector() || VT.getSizeInBits() > Subtarget.getXLen ())
14354+ if (VT != Subtarget.getXLenVT ())
1435314355 return SDValue();
1435414356
1435514357 // Looking for a reg-reg add and not an addi.
@@ -14358,18 +14360,22 @@ static SDValue combineShlAddIAdd(SDNode *N, SelectionDAG &DAG,
1435814360
1435914361 SDValue AddI = N->getOperand(0);
1436014362 SDValue Other = N->getOperand(1);
14361- bool LHSIsAdd = AddI.getOpcode() == ISD::ADD;
14362- bool RHSIsAdd = Other.getOpcode() == ISD::ADD;
14363- int64_t AddConst;
14364- int64_t ShlConst;
14365-
14366- // At least one add is required.
14367- if (!(LHSIsAdd || RHSIsAdd))
14363+ bool LHSIsAddI = SDPatternMatch::sd_match(
14364+ AddI, SDPatternMatch::m_Add(SDPatternMatch::m_Value(),
14365+ SDPatternMatch::m_ConstInt()));
14366+ bool RHSIsAddI = SDPatternMatch::sd_match(
14367+ Other, SDPatternMatch::m_Add(SDPatternMatch::m_Value(),
14368+ SDPatternMatch::m_ConstInt()));
14369+ int64_t AddConst = 0;
14370+ int64_t ShlConst = 0;
14371+
14372+ // At least one addi is required.
14373+ if (!LHSIsAddI && !RHSIsAddI)
1436814374 return SDValue();
1436914375
1437014376 // If the LHS is not the result of an add or both sides are results of an add, but
1437114377 // the LHS does not have the desired structure with a shift, swap the operands.
14372- if (!LHSIsAdd || (LHSIsAdd && RHSIsAdd && !checkAddiForShift(AddI, AddConst, ShlConst)))
14378+ if (!LHSIsAddI || (RHSIsAddI && !checkAddiForShift(AddI, AddConst, ShlConst)))
1437314379 std::swap(AddI, Other);
1437414380
1437514381 // We simply need to ensure AddI has the desired structure.
0 commit comments