Skip to content

Commit 1c1ee6e

Browse files
committed
[LoongArch][DAGCombiner] Combine xor (and ..) to vandn
After this commit, DAGCombiner will have more opportunities to perform vector folding. This patch includes several foldings, as follows: - VANDN(x,NOT(y)) -> AND(NOT(x),NOT(y)) -> NOT(OR(X,Y)) - VANDN(x, SplatVector(Imm)) -> AND(NOT(x), NOT(SplatVector(~Imm)))
1 parent 0dffa25 commit 1c1ee6e

File tree

4 files changed

+185
-26
lines changed

4 files changed

+185
-26
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4939,6 +4939,96 @@ void LoongArchTargetLowering::ReplaceNodeResults(
49394939
}
49404940
}
49414941

4942+
// Check if all elements in build_vector are the same or undef, and if so,
4943+
// return true and set the splat element in SplatValue.
4944+
static bool isSplatOrUndef(SDNode *N, SDValue &SplatValue) {
4945+
if (N->getOpcode() != ISD::BUILD_VECTOR)
4946+
return false;
4947+
for (SDValue Op : N->ops()) {
4948+
if (!Op.isUndef() && SplatValue && Op != SplatValue)
4949+
return false;
4950+
if (!Op.isUndef())
4951+
SplatValue = Op;
4952+
}
4953+
return true;
4954+
}
4955+
4956+
// Helper to attempt to return a cheaper, bit-inverted version of \p V.
4957+
static SDValue isNOT(SDValue V, SelectionDAG &DAG) {
4958+
// TODO: don't always ignore oneuse constraints.
4959+
V = peekThroughBitcasts(V);
4960+
EVT VT = V.getValueType();
4961+
4962+
// Match not(xor X, -1) -> X.
4963+
if (V.getOpcode() == ISD::XOR &&
4964+
(ISD::isBuildVectorAllOnes(V.getOperand(1).getNode()) ||
4965+
isAllOnesConstant(V.getOperand(1))))
4966+
return V.getOperand(0);
4967+
4968+
// Match not(extract_subvector(not(X)) -> extract_subvector(X).
4969+
if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
4970+
(isNullConstant(V.getOperand(1)) || V.getOperand(0).hasOneUse())) {
4971+
if (SDValue Not = isNOT(V.getOperand(0), DAG)) {
4972+
Not = DAG.getBitcast(V.getOperand(0).getValueType(), Not);
4973+
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(Not), VT, Not,
4974+
V.getOperand(1));
4975+
}
4976+
}
4977+
4978+
// Match not(SplatVector(not(X)) -> SplatVector(X).
4979+
SDValue SplatValue;
4980+
if (isSplatOrUndef(V.getNode(), SplatValue) &&
4981+
V->isOnlyUserOf(SplatValue.getNode())) {
4982+
if (SDValue Not = isNOT(SplatValue, DAG)) {
4983+
Not = DAG.getBitcast(V.getOperand(0).getValueType(), Not);
4984+
return DAG.getSplat(VT, SDLoc(Not), Not);
4985+
}
4986+
}
4987+
4988+
// Match not(or(not(X),not(Y))) -> and(X, Y).
4989+
if (V.getOpcode() == ISD::OR && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
4990+
V.getOperand(0).hasOneUse() && V.getOperand(1).hasOneUse()) {
4991+
// TODO: Handle cases with single NOT operand -> VANDN
4992+
if (SDValue Op1 = isNOT(V.getOperand(1), DAG))
4993+
if (SDValue Op0 = isNOT(V.getOperand(0), DAG))
4994+
return DAG.getNode(ISD::AND, SDLoc(V), VT, DAG.getBitcast(VT, Op0),
4995+
DAG.getBitcast(VT, Op1));
4996+
}
4997+
4998+
// TODO: Add more matching patterns. Such as,
4999+
// not(concat_vectors(not(X), not(Y))) -> concat_vectors(X, Y).
5000+
// not(slt(C, X)) -> slt(X - 1, C)
5001+
5002+
return SDValue();
5003+
}
5004+
5005+
/// Try to fold: (and (xor X, -1), Y) -> (vandn X, Y).
5006+
static SDValue combineAndNotIntoVANDN(SDNode *N, const SDLoc &DL,
5007+
SelectionDAG &DAG) {
5008+
assert(N->getOpcode() == ISD::AND && "Unexpected opcode combine into ANDN");
5009+
5010+
MVT VT = N->getSimpleValueType(0);
5011+
if (!VT.is128BitVector() && !VT.is256BitVector())
5012+
return SDValue();
5013+
5014+
SDValue X, Y;
5015+
SDValue N0 = N->getOperand(0);
5016+
SDValue N1 = N->getOperand(1);
5017+
5018+
if (SDValue Not = isNOT(N0, DAG)) {
5019+
X = Not;
5020+
Y = N1;
5021+
} else if (SDValue Not = isNOT(N1, DAG)) {
5022+
X = Not;
5023+
Y = N0;
5024+
} else
5025+
return SDValue();
5026+
5027+
X = DAG.getBitcast(VT, X);
5028+
Y = DAG.getBitcast(VT, Y);
5029+
return DAG.getNode(LoongArchISD::VANDN, DL, VT, X, Y);
5030+
}
5031+
49425032
static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG,
49435033
TargetLowering::DAGCombinerInfo &DCI,
49445034
const LoongArchSubtarget &Subtarget) {
@@ -4960,6 +5050,9 @@ static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG,
49605050
if (!Subtarget.has32S())
49615051
return SDValue();
49625052

5053+
if (SDValue R = combineAndNotIntoVANDN(N, DL, DAG))
5054+
return R;
5055+
49635056
// Op's second operand must be a shifted mask.
49645057
if (!(CN = dyn_cast<ConstantSDNode>(SecondOperand)) ||
49655058
!isShiftedMask_64(CN->getZExtValue(), SMIdx, SMLen))
@@ -6628,6 +6721,65 @@ performEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
66286721
return SDValue();
66296722
}
66306723

6724+
/// Do target-specific dag combines on LoongArchISD::VANDN nodes.
6725+
static SDValue performVANDNCombine(SDNode *N, SelectionDAG &DAG,
6726+
TargetLowering::DAGCombinerInfo &DCI,
6727+
const LoongArchSubtarget &Subtarget) {
6728+
SDValue N0 = N->getOperand(0);
6729+
SDValue N1 = N->getOperand(1);
6730+
MVT VT = N->getSimpleValueType(0);
6731+
SDLoc DL(N);
6732+
6733+
// VANDN(undef, x) -> 0
6734+
// VANDN(x, undef) -> 0
6735+
if (N0.isUndef() || N1.isUndef())
6736+
return DAG.getConstant(0, DL, VT);
6737+
6738+
// VANDN(0, x) -> x
6739+
if (ISD::isBuildVectorAllZeros(N0.getNode()))
6740+
return N1;
6741+
6742+
// VANDN(x, 0) -> 0
6743+
if (ISD::isBuildVectorAllZeros(N1.getNode()))
6744+
return DAG.getConstant(0, DL, VT);
6745+
6746+
// VANDN(x, -1) -> NOT(x) -> XOR(x, -1)
6747+
if (ISD::isBuildVectorAllOnes(N1.getNode()))
6748+
return DAG.getNOT(DL, N0, VT);
6749+
6750+
// Turn VANDN back to AND if input is inverted.
6751+
if (SDValue Not = isNOT(N0, DAG))
6752+
return DAG.getNode(ISD::AND, DL, VT, DAG.getBitcast(VT, Not), N1);
6753+
6754+
// Folds for better commutativity:
6755+
if (N1->hasOneUse()) {
6756+
// VANDN(x,NOT(y)) -> AND(NOT(x),NOT(y)) -> NOT(OR(X,Y)).
6757+
if (SDValue Not = isNOT(N1, DAG))
6758+
return DAG.getNOT(
6759+
DL, DAG.getNode(ISD::OR, DL, VT, N0, DAG.getBitcast(VT, Not)), VT);
6760+
6761+
// VANDN(x, SplatVector(Imm)) -> AND(NOT(x), NOT(SplatVector(~Imm)))
6762+
// -> NOT(OR(x, SplatVector(-Imm))
6763+
// Combination is performed only when VT is v16i8/v32i8, using `vnori.b` to
6764+
// gain benefits.
6765+
if (!DCI.isBeforeLegalizeOps() && (VT == MVT::v16i8 || VT == MVT::v32i8)) {
6766+
SDValue SplatValue;
6767+
if (isSplatOrUndef(N1.getNode(), SplatValue) &&
6768+
N1->isOnlyUserOf(SplatValue.getNode()))
6769+
if (auto *C = dyn_cast<ConstantSDNode>(SplatValue)) {
6770+
uint8_t NCVal = static_cast<uint8_t>(~(C->getSExtValue()));
6771+
SDValue Not =
6772+
DAG.getSplat(VT, DL, DAG.getTargetConstant(NCVal, DL, MVT::i8));
6773+
return DAG.getNOT(
6774+
DL, DAG.getNode(ISD::OR, DL, VT, N0, DAG.getBitcast(VT, Not)),
6775+
VT);
6776+
}
6777+
}
6778+
}
6779+
6780+
return SDValue();
6781+
}
6782+
66316783
SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
66326784
DAGCombinerInfo &DCI) const {
66336785
SelectionDAG &DAG = DCI.DAG;
@@ -6663,6 +6815,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
66636815
return performSPLIT_PAIR_F64Combine(N, DAG, DCI, Subtarget);
66646816
case ISD::EXTRACT_VECTOR_ELT:
66656817
return performEXTRACT_VECTOR_ELTCombine(N, DAG, DCI, Subtarget);
6818+
case LoongArchISD::VANDN:
6819+
return performVANDNCombine(N, DAG, DCI, Subtarget);
66666820
}
66676821
return SDValue();
66686822
}
@@ -7454,6 +7608,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
74547608
NODE_NAME_CASE(VPICK_SEXT_ELT)
74557609
NODE_NAME_CASE(VPICK_ZEXT_ELT)
74567610
NODE_NAME_CASE(VREPLVE)
7611+
NODE_NAME_CASE(VANDN)
74577612
NODE_NAME_CASE(VALL_ZERO)
74587613
NODE_NAME_CASE(VANY_ZERO)
74597614
NODE_NAME_CASE(VALL_NONZERO)

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ enum NodeType : unsigned {
174174
VBSLL,
175175
VBSRL,
176176

177+
// Vector bit operation
178+
VANDN,
179+
177180
// Scalar load broadcast to vector
178181
VLDREPL,
179182

llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ def : Pat<(vnot (or (vt LASX256:$xj), (vt LASX256:$xk))),
13951395
(XVNOR_V LASX256:$xj, LASX256:$xk)>;
13961396
// XVANDN_V
13971397
foreach vt = [v32i8, v16i16, v8i32, v4i64] in
1398-
def : Pat<(and (vt (vnot LASX256:$xj)), (vt LASX256:$xk)),
1398+
def : Pat<(loongarch_vandn (vt LASX256:$xj), (vt LASX256:$xk)),
13991399
(XVANDN_V LASX256:$xj, LASX256:$xk)>;
14001400
// XVORN_V
14011401
foreach vt = [v32i8, v16i16, v8i32, v4i64] in
@@ -1449,25 +1449,25 @@ defm : PatXr<ctlz, "XVCLZ">;
14491449
defm : PatXr<ctpop, "XVPCNT">;
14501450

14511451
// XVBITCLR_{B/H/W/D}
1452-
def : Pat<(and v32i8:$xj, (vnot (shl vsplat_imm_eq_1, v32i8:$xk))),
1452+
def : Pat<(loongarch_vandn (v32i8 (shl vsplat_imm_eq_1, v32i8:$xk)), v32i8:$xj),
14531453
(v32i8 (XVBITCLR_B v32i8:$xj, v32i8:$xk))>;
1454-
def : Pat<(and v16i16:$xj, (vnot (shl vsplat_imm_eq_1, v16i16:$xk))),
1454+
def : Pat<(loongarch_vandn (v16i16 (shl vsplat_imm_eq_1, v16i16:$xk)), v16i16:$xj),
14551455
(v16i16 (XVBITCLR_H v16i16:$xj, v16i16:$xk))>;
1456-
def : Pat<(and v8i32:$xj, (vnot (shl vsplat_imm_eq_1, v8i32:$xk))),
1456+
def : Pat<(loongarch_vandn (v8i32 (shl vsplat_imm_eq_1, v8i32:$xk)), v8i32:$xj),
14571457
(v8i32 (XVBITCLR_W v8i32:$xj, v8i32:$xk))>;
1458-
def : Pat<(and v4i64:$xj, (vnot (shl vsplat_imm_eq_1, v4i64:$xk))),
1458+
def : Pat<(loongarch_vandn (v4i64 (shl vsplat_imm_eq_1, v4i64:$xk)), v4i64:$xj),
14591459
(v4i64 (XVBITCLR_D v4i64:$xj, v4i64:$xk))>;
1460-
def : Pat<(and v32i8:$xj, (vnot (shl vsplat_imm_eq_1,
1461-
(vsplati8imm7 v32i8:$xk)))),
1460+
def : Pat<(loongarch_vandn (v32i8 (shl vsplat_imm_eq_1,
1461+
(vsplati8imm7 v32i8:$xk))), v32i8:$xj),
14621462
(v32i8 (XVBITCLR_B v32i8:$xj, v32i8:$xk))>;
1463-
def : Pat<(and v16i16:$xj, (vnot (shl vsplat_imm_eq_1,
1464-
(vsplati16imm15 v16i16:$xk)))),
1463+
def : Pat<(loongarch_vandn (v16i16 (shl vsplat_imm_eq_1,
1464+
(vsplati16imm15 v16i16:$xk))), v16i16:$xj),
14651465
(v16i16 (XVBITCLR_H v16i16:$xj, v16i16:$xk))>;
1466-
def : Pat<(and v8i32:$xj, (vnot (shl vsplat_imm_eq_1,
1467-
(vsplati32imm31 v8i32:$xk)))),
1466+
def : Pat<(loongarch_vandn (v8i32 (shl vsplat_imm_eq_1,
1467+
(vsplati32imm31 v8i32:$xk))), v8i32:$xj),
14681468
(v8i32 (XVBITCLR_W v8i32:$xj, v8i32:$xk))>;
1469-
def : Pat<(and v4i64:$xj, (vnot (shl vsplat_imm_eq_1,
1470-
(vsplati64imm63 v4i64:$xk)))),
1469+
def : Pat<(loongarch_vandn (v4i64 (shl vsplat_imm_eq_1,
1470+
(vsplati64imm63 v4i64:$xk))), v4i64:$xj),
14711471
(v4i64 (XVBITCLR_D v4i64:$xj, v4i64:$xk))>;
14721472

14731473
// XVBITCLRI_{B/H/W/D}

llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def loongarch_vpackev: SDNode<"LoongArchISD::VPACKEV", SDT_LoongArchV2R>;
5656
def loongarch_vpackod: SDNode<"LoongArchISD::VPACKOD", SDT_LoongArchV2R>;
5757
def loongarch_vilvl: SDNode<"LoongArchISD::VILVL", SDT_LoongArchV2R>;
5858
def loongarch_vilvh: SDNode<"LoongArchISD::VILVH", SDT_LoongArchV2R>;
59+
def loongarch_vandn: SDNode<"LoongArchISD::VANDN", SDT_LoongArchV2R>;
5960

6061
def loongarch_vshuf4i: SDNode<"LoongArchISD::VSHUF4I", SDT_LoongArchV1RUimm>;
6162
def loongarch_vshuf4i_d : SDNode<"LoongArchISD::VSHUF4I", SDT_LoongArchV2RUimm>;
@@ -1586,7 +1587,7 @@ def : Pat<(vnot (or (vt LSX128:$vj), (vt LSX128:$vk))),
15861587
(VNOR_V LSX128:$vj, LSX128:$vk)>;
15871588
// VANDN_V
15881589
foreach vt = [v16i8, v8i16, v4i32, v2i64] in
1589-
def : Pat<(and (vt (vnot LSX128:$vj)), (vt LSX128:$vk)),
1590+
def : Pat<(loongarch_vandn (vt LSX128:$vj), (vt LSX128:$vk)),
15901591
(VANDN_V LSX128:$vj, LSX128:$vk)>;
15911592
// VORN_V
15921593
foreach vt = [v16i8, v8i16, v4i32, v2i64] in
@@ -1640,25 +1641,25 @@ defm : PatVr<ctlz, "VCLZ">;
16401641
defm : PatVr<ctpop, "VPCNT">;
16411642

16421643
// VBITCLR_{B/H/W/D}
1643-
def : Pat<(and v16i8:$vj, (vnot (shl vsplat_imm_eq_1, v16i8:$vk))),
1644+
def : Pat<(loongarch_vandn (v16i8 (shl vsplat_imm_eq_1, v16i8:$vk)), v16i8:$vj),
16441645
(v16i8 (VBITCLR_B v16i8:$vj, v16i8:$vk))>;
1645-
def : Pat<(and v8i16:$vj, (vnot (shl vsplat_imm_eq_1, v8i16:$vk))),
1646+
def : Pat<(loongarch_vandn (v8i16 (shl vsplat_imm_eq_1, v8i16:$vk)), v8i16:$vj),
16461647
(v8i16 (VBITCLR_H v8i16:$vj, v8i16:$vk))>;
1647-
def : Pat<(and v4i32:$vj, (vnot (shl vsplat_imm_eq_1, v4i32:$vk))),
1648+
def : Pat<(loongarch_vandn (v4i32 (shl vsplat_imm_eq_1, v4i32:$vk)), v4i32:$vj),
16481649
(v4i32 (VBITCLR_W v4i32:$vj, v4i32:$vk))>;
1649-
def : Pat<(and v2i64:$vj, (vnot (shl vsplat_imm_eq_1, v2i64:$vk))),
1650+
def : Pat<(loongarch_vandn (v2i64 (shl vsplat_imm_eq_1, v2i64:$vk)), v2i64:$vj),
16501651
(v2i64 (VBITCLR_D v2i64:$vj, v2i64:$vk))>;
1651-
def : Pat<(and v16i8:$vj, (vnot (shl vsplat_imm_eq_1,
1652-
(vsplati8imm7 v16i8:$vk)))),
1652+
def : Pat<(loongarch_vandn (v16i8 (shl vsplat_imm_eq_1,
1653+
(vsplati8imm7 v16i8:$vk))), v16i8:$vj),
16531654
(v16i8 (VBITCLR_B v16i8:$vj, v16i8:$vk))>;
1654-
def : Pat<(and v8i16:$vj, (vnot (shl vsplat_imm_eq_1,
1655-
(vsplati16imm15 v8i16:$vk)))),
1655+
def : Pat<(loongarch_vandn (v8i16 (shl vsplat_imm_eq_1,
1656+
(vsplati16imm15 v8i16:$vk))), v8i16:$vj),
16561657
(v8i16 (VBITCLR_H v8i16:$vj, v8i16:$vk))>;
1657-
def : Pat<(and v4i32:$vj, (vnot (shl vsplat_imm_eq_1,
1658-
(vsplati32imm31 v4i32:$vk)))),
1658+
def : Pat<(loongarch_vandn (v4i32 (shl vsplat_imm_eq_1,
1659+
(vsplati32imm31 v4i32:$vk))), v4i32:$vj),
16591660
(v4i32 (VBITCLR_W v4i32:$vj, v4i32:$vk))>;
1660-
def : Pat<(and v2i64:$vj, (vnot (shl vsplat_imm_eq_1,
1661-
(vsplati64imm63 v2i64:$vk)))),
1661+
def : Pat<(loongarch_vandn (v2i64 (shl vsplat_imm_eq_1,
1662+
(vsplati64imm63 v2i64:$vk))), v2i64:$vj),
16621663
(v2i64 (VBITCLR_D v2i64:$vj, v2i64:$vk))>;
16631664

16641665
// VBITCLRI_{B/H/W/D}

0 commit comments

Comments
 (0)