Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,59 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
return SDValue();
}

// Helper to attempt to return a cheaper, bit-inverted version of \p V.
// If \p V is equivalent to NOT(X) for some cheaper X, returns X; otherwise
// returns an empty SDValue. Callers use this to rewrite AND/VANDN patterns.
static SDValue isNOT(SDValue V, SelectionDAG &DAG) {
  // TODO: don't always ignore oneuse constraints.
  // Bitcasts don't change the underlying bit pattern, so look through them.
  V = peekThroughBitcasts(V);
  EVT VT = V.getValueType();

  // Match not(xor X, -1) -> X.
  // Covers both the scalar all-ones constant and a build_vector of all-ones.
  if (V.getOpcode() == ISD::XOR &&
      (ISD::isBuildVectorAllOnes(V.getOperand(1).getNode()) ||
       isAllOnesConstant(V.getOperand(1))))
    return V.getOperand(0);

  // Match not(extract_subvector(not(X)) -> extract_subvector(X).
  // Restricted to extraction at index 0 or a single-use source, so the wide
  // inverted vector is not kept alive just for this narrow use.
  if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
      (isNullConstant(V.getOperand(1)) || V.getOperand(0).hasOneUse())) {
    if (SDValue Not = isNOT(V.getOperand(0), DAG)) {
      // The recursion may have peeked through bitcasts; restore the source
      // vector type expected by EXTRACT_SUBVECTOR before rebuilding it.
      Not = DAG.getBitcast(V.getOperand(0).getValueType(), Not);
      return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(Not), VT, Not,
                         V.getOperand(1));
    }
  }

  // Match not(SplatVector(not(X)) -> SplatVector(X).
  if (V.getOpcode() == ISD::BUILD_VECTOR) {
    if (SDValue SplatValue =
            cast<BuildVectorSDNode>(V.getNode())->getSplatValue()) {
      // Give up if the splat scalar has users besides this build_vector;
      // inverting it would not let the original nodes fold away.
      if (!V->isOnlyUserOf(SplatValue.getNode()))
        return SDValue();

      if (SDValue Not = isNOT(SplatValue, DAG)) {
        // Cast back to the element type (operand 0's type) before splatting.
        Not = DAG.getBitcast(V.getOperand(0).getValueType(), Not);
        return DAG.getSplat(VT, SDLoc(Not), Not);
      }
    }
  }

  // Match not(or(not(X),not(Y))) -> and(X, Y)  (De Morgan).
  // Both operands must be single-use so the inner NOTs can be removed.
  if (V.getOpcode() == ISD::OR && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
      V.getOperand(0).hasOneUse() && V.getOperand(1).hasOneUse()) {
    // TODO: Handle cases with single NOT operand -> VANDN
    if (SDValue Op1 = isNOT(V.getOperand(1), DAG))
      if (SDValue Op0 = isNOT(V.getOperand(0), DAG))
        return DAG.getNode(ISD::AND, SDLoc(V), VT, DAG.getBitcast(VT, Op0),
                           DAG.getBitcast(VT, Op1));
  }

  // TODO: Add more matching patterns. Such as,
  // not(concat_vectors(not(X), not(Y))) -> concat_vectors(X, Y).
  // not(slt(C, X)) -> slt(X - 1, C)

  return SDValue();
}

SDValue LoongArchTargetLowering::lowerConstantFP(SDValue Op,
SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
Expand Down Expand Up @@ -4939,6 +4992,33 @@ void LoongArchTargetLowering::ReplaceNodeResults(
}
}

/// Try to fold: (and (xor X, -1), Y) -> (vandn X, Y).
/// AND is commutative, so the inverted operand may appear on either side;
/// the un-inverted operand becomes the second VANDN operand.
static SDValue combineAndNotIntoVANDN(SDNode *N, const SDLoc &DL,
                                      SelectionDAG &DAG) {
  assert(N->getOpcode() == ISD::AND && "Unexpected opcode combine into ANDN");

  // VANDN only exists for 128-bit (LSX) and 256-bit (LASX) vectors.
  MVT VT = N->getSimpleValueType(0);
  if (!VT.is128BitVector() && !VT.is256BitVector())
    return SDValue();

  SDValue Lhs = N->getOperand(0);
  SDValue Rhs = N->getOperand(1);

  // Check the left operand for an inversion first, then the right one.
  for (unsigned Idx = 0; Idx != 2; ++Idx) {
    SDValue Cand = (Idx == 0) ? Lhs : Rhs;
    SDValue Other = (Idx == 0) ? Rhs : Lhs;
    if (SDValue Inverted = isNOT(Cand, DAG)) {
      // isNOT may have looked through bitcasts; normalize both operands
      // back to the AND's value type before emitting the VANDN node.
      return DAG.getNode(LoongArchISD::VANDN, DL, VT,
                         DAG.getBitcast(VT, Inverted),
                         DAG.getBitcast(VT, Other));
    }
  }

  return SDValue();
}

static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const LoongArchSubtarget &Subtarget) {
Expand All @@ -4956,6 +5036,9 @@ static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG,
SDValue NewOperand;
MVT GRLenVT = Subtarget.getGRLenVT();

if (SDValue R = combineAndNotIntoVANDN(N, DL, DAG))
return R;

// BSTRPICK requires the 32S feature.
if (!Subtarget.has32S())
return SDValue();
Expand Down Expand Up @@ -6628,6 +6711,69 @@ performEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

/// Do target-specific dag combines on LoongArchISD::VANDN nodes.
/// VANDN(x, y) computes AND(NOT(x), y) (see the constant folds below and
/// combineAndNotIntoVANDN, which builds VANDN from that pattern).
static SDValue performVANDNCombine(SDNode *N, SelectionDAG &DAG,
                                   TargetLowering::DAGCombinerInfo &DCI,
                                   const LoongArchSubtarget &Subtarget) {
  SDValue N0 = N->getOperand(0);
  SDValue N1 = N->getOperand(1);
  MVT VT = N->getSimpleValueType(0);
  SDLoc DL(N);

  // VANDN(undef, x) -> 0
  // VANDN(x, undef) -> 0
  if (N0.isUndef() || N1.isUndef())
    return DAG.getConstant(0, DL, VT);

  // VANDN(0, x) -> x   (NOT(0) is all-ones, which is the AND identity)
  if (ISD::isBuildVectorAllZeros(N0.getNode()))
    return N1;

  // VANDN(x, 0) -> 0   (anything ANDed with zero is zero)
  if (ISD::isBuildVectorAllZeros(N1.getNode()))
    return DAG.getConstant(0, DL, VT);

  // VANDN(x, -1) -> NOT(x) -> XOR(x, -1)
  if (ISD::isBuildVectorAllOnes(N1.getNode()))
    return DAG.getNOT(DL, N0, VT);

  // Turn VANDN back to AND if input is inverted: the two NOTs cancel and
  // plain AND exposes more generic combines.
  if (SDValue Not = isNOT(N0, DAG))
    return DAG.getNode(ISD::AND, DL, VT, DAG.getBitcast(VT, Not), N1);

  // Folds for better commutativity:
  if (N1->hasOneUse()) {
    // VANDN(x,NOT(y)) -> AND(NOT(x),NOT(y)) -> NOT(OR(X,Y)).
    if (SDValue Not = isNOT(N1, DAG))
      return DAG.getNOT(
          DL, DAG.getNode(ISD::OR, DL, VT, N0, DAG.getBitcast(VT, Not)), VT);

    // VANDN(x, SplatVector(Imm)) -> AND(NOT(x), NOT(SplatVector(~Imm)))
    // -> NOT(OR(x, SplatVector(-Imm))
    // Combination is performed only when VT is v16i8/v32i8, using `vnori.b` to
    // gain benefits.
    if (!DCI.isBeforeLegalizeOps() && (VT == MVT::v16i8 || VT == MVT::v32i8) &&
        N1.getOpcode() == ISD::BUILD_VECTOR) {
      if (SDValue SplatValue =
              cast<BuildVectorSDNode>(N1.getNode())->getSplatValue()) {
        // Bail out if the splat scalar has other users; rewriting would not
        // remove the original build_vector.
        if (!N1->isOnlyUserOf(SplatValue.getNode()))
          return SDValue();

        if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) {
          // Bitwise-invert the 8-bit immediate for the inner splat.
          uint8_t NCVal = static_cast<uint8_t>(~(C->getSExtValue()));
          SDValue Not =
              DAG.getSplat(VT, DL, DAG.getTargetConstant(NCVal, DL, MVT::i8));
          return DAG.getNOT(
              DL, DAG.getNode(ISD::OR, DL, VT, N0, DAG.getBitcast(VT, Not)),
              VT);
        }
      }
    }
  }

  return SDValue();
}

SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
Expand Down Expand Up @@ -6663,6 +6809,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
return performSPLIT_PAIR_F64Combine(N, DAG, DCI, Subtarget);
case ISD::EXTRACT_VECTOR_ELT:
return performEXTRACT_VECTOR_ELTCombine(N, DAG, DCI, Subtarget);
case LoongArchISD::VANDN:
return performVANDNCombine(N, DAG, DCI, Subtarget);
}
return SDValue();
}
Expand Down Expand Up @@ -7454,6 +7602,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VPICK_SEXT_ELT)
NODE_NAME_CASE(VPICK_ZEXT_ELT)
NODE_NAME_CASE(VREPLVE)
NODE_NAME_CASE(VANDN)
NODE_NAME_CASE(VALL_ZERO)
NODE_NAME_CASE(VANY_ZERO)
NODE_NAME_CASE(VALL_NONZERO)
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/LoongArch/LoongArchISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ enum NodeType : unsigned {
VBSLL,
VBSRL,

// Vector bit operation
VANDN,

// Scalar load broadcast to vector
VLDREPL,

Expand Down
26 changes: 13 additions & 13 deletions llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ def : Pat<(vnot (or (vt LASX256:$xj), (vt LASX256:$xk))),
(XVNOR_V LASX256:$xj, LASX256:$xk)>;
// XVANDN_V
foreach vt = [v32i8, v16i16, v8i32, v4i64] in
def : Pat<(and (vt (vnot LASX256:$xj)), (vt LASX256:$xk)),
def : Pat<(loongarch_vandn (vt LASX256:$xj), (vt LASX256:$xk)),
(XVANDN_V LASX256:$xj, LASX256:$xk)>;
// XVORN_V
foreach vt = [v32i8, v16i16, v8i32, v4i64] in
Expand Down Expand Up @@ -1449,25 +1449,25 @@ defm : PatXr<ctlz, "XVCLZ">;
defm : PatXr<ctpop, "XVPCNT">;

// XVBITCLR_{B/H/W/D}
def : Pat<(and v32i8:$xj, (vnot (shl vsplat_imm_eq_1, v32i8:$xk))),
def : Pat<(loongarch_vandn (v32i8 (shl vsplat_imm_eq_1, v32i8:$xk)), v32i8:$xj),
(v32i8 (XVBITCLR_B v32i8:$xj, v32i8:$xk))>;
def : Pat<(and v16i16:$xj, (vnot (shl vsplat_imm_eq_1, v16i16:$xk))),
def : Pat<(loongarch_vandn (v16i16 (shl vsplat_imm_eq_1, v16i16:$xk)), v16i16:$xj),
(v16i16 (XVBITCLR_H v16i16:$xj, v16i16:$xk))>;
def : Pat<(and v8i32:$xj, (vnot (shl vsplat_imm_eq_1, v8i32:$xk))),
def : Pat<(loongarch_vandn (v8i32 (shl vsplat_imm_eq_1, v8i32:$xk)), v8i32:$xj),
(v8i32 (XVBITCLR_W v8i32:$xj, v8i32:$xk))>;
def : Pat<(and v4i64:$xj, (vnot (shl vsplat_imm_eq_1, v4i64:$xk))),
def : Pat<(loongarch_vandn (v4i64 (shl vsplat_imm_eq_1, v4i64:$xk)), v4i64:$xj),
(v4i64 (XVBITCLR_D v4i64:$xj, v4i64:$xk))>;
def : Pat<(and v32i8:$xj, (vnot (shl vsplat_imm_eq_1,
(vsplati8imm7 v32i8:$xk)))),
def : Pat<(loongarch_vandn (v32i8 (shl vsplat_imm_eq_1,
(vsplati8imm7 v32i8:$xk))), v32i8:$xj),
(v32i8 (XVBITCLR_B v32i8:$xj, v32i8:$xk))>;
def : Pat<(and v16i16:$xj, (vnot (shl vsplat_imm_eq_1,
(vsplati16imm15 v16i16:$xk)))),
def : Pat<(loongarch_vandn (v16i16 (shl vsplat_imm_eq_1,
(vsplati16imm15 v16i16:$xk))), v16i16:$xj),
(v16i16 (XVBITCLR_H v16i16:$xj, v16i16:$xk))>;
def : Pat<(and v8i32:$xj, (vnot (shl vsplat_imm_eq_1,
(vsplati32imm31 v8i32:$xk)))),
def : Pat<(loongarch_vandn (v8i32 (shl vsplat_imm_eq_1,
(vsplati32imm31 v8i32:$xk))), v8i32:$xj),
(v8i32 (XVBITCLR_W v8i32:$xj, v8i32:$xk))>;
def : Pat<(and v4i64:$xj, (vnot (shl vsplat_imm_eq_1,
(vsplati64imm63 v4i64:$xk)))),
def : Pat<(loongarch_vandn (v4i64 (shl vsplat_imm_eq_1,
(vsplati64imm63 v4i64:$xk))), v4i64:$xj),
(v4i64 (XVBITCLR_D v4i64:$xj, v4i64:$xk))>;

// XVBITCLRI_{B/H/W/D}
Expand Down
27 changes: 14 additions & 13 deletions llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def loongarch_vpackev: SDNode<"LoongArchISD::VPACKEV", SDT_LoongArchV2R>;
def loongarch_vpackod: SDNode<"LoongArchISD::VPACKOD", SDT_LoongArchV2R>;
def loongarch_vilvl: SDNode<"LoongArchISD::VILVL", SDT_LoongArchV2R>;
def loongarch_vilvh: SDNode<"LoongArchISD::VILVH", SDT_LoongArchV2R>;
def loongarch_vandn: SDNode<"LoongArchISD::VANDN", SDT_LoongArchV2R>;

def loongarch_vshuf4i: SDNode<"LoongArchISD::VSHUF4I", SDT_LoongArchV1RUimm>;
def loongarch_vshuf4i_d : SDNode<"LoongArchISD::VSHUF4I", SDT_LoongArchV2RUimm>;
Expand Down Expand Up @@ -1586,7 +1587,7 @@ def : Pat<(vnot (or (vt LSX128:$vj), (vt LSX128:$vk))),
(VNOR_V LSX128:$vj, LSX128:$vk)>;
// VANDN_V
foreach vt = [v16i8, v8i16, v4i32, v2i64] in
def : Pat<(and (vt (vnot LSX128:$vj)), (vt LSX128:$vk)),
def : Pat<(loongarch_vandn (vt LSX128:$vj), (vt LSX128:$vk)),
(VANDN_V LSX128:$vj, LSX128:$vk)>;
// VORN_V
foreach vt = [v16i8, v8i16, v4i32, v2i64] in
Expand Down Expand Up @@ -1640,25 +1641,25 @@ defm : PatVr<ctlz, "VCLZ">;
defm : PatVr<ctpop, "VPCNT">;

// VBITCLR_{B/H/W/D}
def : Pat<(and v16i8:$vj, (vnot (shl vsplat_imm_eq_1, v16i8:$vk))),
def : Pat<(loongarch_vandn (v16i8 (shl vsplat_imm_eq_1, v16i8:$vk)), v16i8:$vj),
(v16i8 (VBITCLR_B v16i8:$vj, v16i8:$vk))>;
def : Pat<(and v8i16:$vj, (vnot (shl vsplat_imm_eq_1, v8i16:$vk))),
def : Pat<(loongarch_vandn (v8i16 (shl vsplat_imm_eq_1, v8i16:$vk)), v8i16:$vj),
(v8i16 (VBITCLR_H v8i16:$vj, v8i16:$vk))>;
def : Pat<(and v4i32:$vj, (vnot (shl vsplat_imm_eq_1, v4i32:$vk))),
def : Pat<(loongarch_vandn (v4i32 (shl vsplat_imm_eq_1, v4i32:$vk)), v4i32:$vj),
(v4i32 (VBITCLR_W v4i32:$vj, v4i32:$vk))>;
def : Pat<(and v2i64:$vj, (vnot (shl vsplat_imm_eq_1, v2i64:$vk))),
def : Pat<(loongarch_vandn (v2i64 (shl vsplat_imm_eq_1, v2i64:$vk)), v2i64:$vj),
(v2i64 (VBITCLR_D v2i64:$vj, v2i64:$vk))>;
def : Pat<(and v16i8:$vj, (vnot (shl vsplat_imm_eq_1,
(vsplati8imm7 v16i8:$vk)))),
def : Pat<(loongarch_vandn (v16i8 (shl vsplat_imm_eq_1,
(vsplati8imm7 v16i8:$vk))), v16i8:$vj),
(v16i8 (VBITCLR_B v16i8:$vj, v16i8:$vk))>;
def : Pat<(and v8i16:$vj, (vnot (shl vsplat_imm_eq_1,
(vsplati16imm15 v8i16:$vk)))),
def : Pat<(loongarch_vandn (v8i16 (shl vsplat_imm_eq_1,
(vsplati16imm15 v8i16:$vk))), v8i16:$vj),
(v8i16 (VBITCLR_H v8i16:$vj, v8i16:$vk))>;
def : Pat<(and v4i32:$vj, (vnot (shl vsplat_imm_eq_1,
(vsplati32imm31 v4i32:$vk)))),
def : Pat<(loongarch_vandn (v4i32 (shl vsplat_imm_eq_1,
(vsplati32imm31 v4i32:$vk))), v4i32:$vj),
(v4i32 (VBITCLR_W v4i32:$vj, v4i32:$vk))>;
def : Pat<(and v2i64:$vj, (vnot (shl vsplat_imm_eq_1,
(vsplati64imm63 v2i64:$vk)))),
def : Pat<(loongarch_vandn (v2i64 (shl vsplat_imm_eq_1,
(vsplati64imm63 v2i64:$vk))), v2i64:$vj),
(v2i64 (VBITCLR_D v2i64:$vj, v2i64:$vk))>;

// VBITCLRI_{B/H/W/D}
Expand Down
Loading