Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 82 additions & 6 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1531,7 +1531,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
ISD::INSERT_VECTOR_ELT, ISD::ABS, ISD::CTPOP,
ISD::VECTOR_SHUFFLE});
ISD::VSELECT, ISD::VECTOR_SHUFFLE});
if (Subtarget.hasVendorXTHeadMemPair())
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
if (Subtarget.useRVVForFixedLengthVectors())
Expand Down Expand Up @@ -16798,6 +16798,53 @@ static SDValue useInversedSetcc(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

static bool matchSelectAddSub(SDValue TrueVal, SDValue FalseVal, bool &SwapCC) {
if (!TrueVal.hasOneUse() || !FalseVal.hasOneUse())
return false;

SwapCC = false;
if (TrueVal.getOpcode() == ISD::SUB && FalseVal.getOpcode() == ISD::ADD) {
std::swap(TrueVal, FalseVal);
SwapCC = true;
}

if (TrueVal.getOpcode() != ISD::ADD || FalseVal.getOpcode() != ISD::SUB)
return false;

SDValue A = FalseVal.getOperand(0);
SDValue B = FalseVal.getOperand(1);
// Add is associative, so check both orders
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

associative->commutative?

return ((TrueVal.getOperand(0) == A && TrueVal.getOperand(1) == B) ||
(TrueVal.getOperand(1) == A && TrueVal.getOperand(0) == B));
}

/// Convert vselect CC, (add a, b), (sub a, b) to add a, (vselect CC, -b, b).
/// This allows us match a vadd.vv fed by a masked vrsub, which reduces
/// register pressure over the add followed by masked vsub sequence.
static SDValue performVSELECTCombine(SDNode *N, SelectionDAG &DAG) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
SDValue CC = N->getOperand(0);
SDValue TrueVal = N->getOperand(1);
SDValue FalseVal = N->getOperand(2);

bool SwapCC;
if (!matchSelectAddSub(TrueVal, FalseVal, SwapCC))
return SDValue();

SDValue Sub = SwapCC ? TrueVal : FalseVal;
SDValue A = Sub.getOperand(0);
SDValue B = Sub.getOperand(1);

// Arrange the select such that we can match a masked
// vrsub.vi to perform the conditional negate
SDValue NegB = DAG.getNegative(B, DL, VT);
if (!SwapCC)
CC = DAG.getLogicalNOT(DL, CC, CC->getValueType(0));
SDValue NewB = DAG.getNode(ISD::VSELECT, DL, VT, CC, NegB, B);
return DAG.getNode(ISD::ADD, DL, VT, A, NewB);
}

static SDValue performSELECTCombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
if (SDValue Folded = foldSelectOfCTTZOrCTLZ(N, DAG))
Expand Down Expand Up @@ -17077,20 +17124,48 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
return DAG.getBitcast(VT.getSimpleVT(), StridedLoad);
}

/// Custom legalize <N x i128> or <N x i256> to <M x ELEN>. This runs
/// during the combine phase before type legalization, and relies on
/// DAGCombine not undoing the transform if isShuffleMaskLegal returns false
/// for the source mask.
static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget,
const RISCVTargetLowering &TLI) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
const unsigned ElementSize = VT.getScalarSizeInBits();
const unsigned NumElts = VT.getVectorNumElements();
SDValue V1 = N->getOperand(0);
SDValue V2 = N->getOperand(1);
ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(N)->getMask();
MVT XLenVT = Subtarget.getXLenVT();

// Recognized a disguised select of add/sub.
bool SwapCC;
if (ShuffleVectorInst::isSelectMask(Mask, NumElts) &&
matchSelectAddSub(V1, V2, SwapCC)) {
SDValue Sub = SwapCC ? V1 : V2;
SDValue A = Sub.getOperand(0);
SDValue B = Sub.getOperand(1);

SmallVector<SDValue> MaskVals;
for (int MaskIndex : Mask) {
bool SelectMaskVal = (MaskIndex < (int)NumElts);
MaskVals.push_back(DAG.getConstant(SelectMaskVal, DL, XLenVT));
}
assert(MaskVals.size() == NumElts && "Unexpected select-like shuffle");
MVT MaskVT = MVT::getVectorVT(MVT::i1, NumElts);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need EVT here since we could be pre-type legalization.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I do. That's what I get for copying code around without cross checking. Will update with a fix Monday.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this to EVT, but oddly, my attempt at writing a test case failed. I tried a number of illegal types, and couldn't get the crash to trigger.

SDValue CC = DAG.getBuildVector(MaskVT, DL, MaskVals);

// Arrange the select such that we can match a masked
// vrsub.vi to perform the conditional negate
SDValue NegB = DAG.getNegative(B, DL, VT);
if (!SwapCC)
CC = DAG.getLogicalNOT(DL, CC, CC->getValueType(0));
SDValue NewB = DAG.getNode(ISD::VSELECT, DL, VT, CC, NegB, B);
return DAG.getNode(ISD::ADD, DL, VT, A, NewB);
}

// Custom legalize <N x i128> or <N x i256> to <M x ELEN>. This runs
// during the combine phase before type legalization, and relies on
// DAGCombine not undoing the transform if isShuffleMaskLegal returns false
// for the source mask.
if (TLI.isTypeLegal(VT) || ElementSize <= Subtarget.getELen() ||
!isPowerOf2_64(ElementSize) || VT.getVectorNumElements() % 2 != 0 ||
VT.isFloatingPoint() || TLI.isShuffleMaskLegal(Mask, VT))
Expand All @@ -17107,7 +17182,6 @@ static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
return DAG.getBitcast(VT, Res);
}


static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {

Expand Down Expand Up @@ -17781,6 +17855,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
return performSELECTCombine(N, DAG, Subtarget);
case ISD::VSELECT:
return performVSELECTCombine(N, DAG);
case RISCVISD::CZERO_EQZ:
case RISCVISD::CZERO_NEZ: {
SDValue Val = N->getOperand(0);
Expand Down
Loading
Loading