Skip to content

Commit 95a4c9c

Browse files
authored
[LoongArch] Custom lower vecreduce_add. (#154304)
1 parent af0f85c commit 95a4c9c

File tree

8 files changed

+231
-72
lines changed

8 files changed

+231
-72
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
340340
{MVT::v16i8, MVT::v8i8, MVT::v4i8, MVT::v2i8, MVT::v8i16, MVT::v4i16,
341341
MVT::v2i16, MVT::v4i32, MVT::v2i32, MVT::v2i64}) {
342342
setOperationAction(ISD::TRUNCATE, VT, Custom);
343+
setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
343344
}
344345
}
345346

@@ -377,6 +378,7 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
377378
setOperationAction(ISD::SCALAR_TO_VECTOR, VT, Custom);
378379
setOperationAction(ISD::ABDS, VT, Legal);
379380
setOperationAction(ISD::ABDU, VT, Legal);
381+
setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
380382
}
381383
for (MVT VT : {MVT::v32i8, MVT::v16i16, MVT::v8i32})
382384
setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -522,10 +524,62 @@ SDValue LoongArchTargetLowering::LowerOperation(SDValue Op,
522524
return lowerFP_TO_BF16(Op, DAG);
523525
case ISD::BF16_TO_FP:
524526
return lowerBF16_TO_FP(Op, DAG);
527+
case ISD::VECREDUCE_ADD:
528+
return lowerVECREDUCE_ADD(Op, DAG);
525529
}
526530
return SDValue();
527531
}
528532

533+
// Lower vecreduce_add using vhaddw instructions.
534+
// For Example:
535+
// call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
536+
// can be lowered to:
537+
// VHADDW_D_W vr0, vr0, vr0
538+
// VHADDW_Q_D vr0, vr0, vr0
539+
// VPICKVE2GR_D a0, vr0, 0
540+
// ADDI_W a0, a0, 0
541+
SDValue LoongArchTargetLowering::lowerVECREDUCE_ADD(SDValue Op,
542+
SelectionDAG &DAG) const {
543+
544+
SDLoc DL(Op);
545+
MVT OpVT = Op.getSimpleValueType();
546+
SDValue Val = Op.getOperand(0);
547+
548+
unsigned NumEles = Val.getSimpleValueType().getVectorNumElements();
549+
unsigned EleBits = Val.getSimpleValueType().getScalarSizeInBits();
550+
551+
unsigned LegalVecSize = 128;
552+
bool isLASX256Vector =
553+
Subtarget.hasExtLASX() && Val.getValueSizeInBits() == 256;
554+
555+
// Ensure operand type legal or enable it legal.
556+
while (!isTypeLegal(Val.getSimpleValueType())) {
557+
Val = DAG.WidenVector(Val, DL);
558+
}
559+
560+
// NumEles is designed for iterations count, v4i32 for LSX
561+
// and v8i32 for LASX should have the same count.
562+
if (isLASX256Vector) {
563+
NumEles /= 2;
564+
LegalVecSize = 256;
565+
}
566+
567+
for (unsigned i = 1; i < NumEles; i *= 2, EleBits *= 2) {
568+
MVT IntTy = MVT::getIntegerVT(EleBits);
569+
MVT VecTy = MVT::getVectorVT(IntTy, LegalVecSize / EleBits);
570+
Val = DAG.getNode(LoongArchISD::VHADDW, DL, VecTy, Val, Val);
571+
}
572+
573+
if (isLASX256Vector) {
574+
SDValue Tmp = DAG.getNode(LoongArchISD::XVPERMI, DL, MVT::v4i64, Val,
575+
DAG.getConstant(2, DL, MVT::i64));
576+
Val = DAG.getNode(ISD::ADD, DL, MVT::v4i64, Tmp, Val);
577+
}
578+
579+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Val,
580+
DAG.getConstant(0, DL, Subtarget.getGRLenVT()));
581+
}
582+
529583
SDValue LoongArchTargetLowering::lowerPREFETCH(SDValue Op,
530584
SelectionDAG &DAG) const {
531585
unsigned IsData = Op.getConstantOperandVal(4);
@@ -6659,6 +6713,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
66596713
NODE_NAME_CASE(XVMSKGEZ)
66606714
NODE_NAME_CASE(XVMSKEQZ)
66616715
NODE_NAME_CASE(XVMSKNEZ)
6716+
NODE_NAME_CASE(VHADDW)
66626717
}
66636718
#undef NODE_NAME_CASE
66646719
return nullptr;

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ enum NodeType : unsigned {
177177
XVMSKEQZ,
178178
XVMSKNEZ,
179179

180+
// Vector Horizontal Addition with Widening‌
181+
VHADDW
182+
180183
// Intrinsic operations end =============================================
181184
};
182185
} // end namespace LoongArchISD
@@ -386,6 +389,7 @@ class LoongArchTargetLowering : public TargetLowering {
386389
SDValue lowerFP16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
387390
SDValue lowerFP_TO_BF16(SDValue Op, SelectionDAG &DAG) const;
388391
SDValue lowerBF16_TO_FP(SDValue Op, SelectionDAG &DAG) const;
392+
SDValue lowerVECREDUCE_ADD(SDValue Op, SelectionDAG &DAG) const;
389393

390394
bool isFPImmLegal(const APFloat &Imm, EVT VT,
391395
bool ForCodeSize) const override;

llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,6 +1186,17 @@ multiclass PatXrXrXr<SDPatternOperator OpNode, string Inst> {
11861186
(!cast<LAInst>(Inst#"_D") LASX256:$xd, LASX256:$xj, LASX256:$xk)>;
11871187
}
11881188

1189+
multiclass PatXrXrW<SDPatternOperator OpNode, string Inst> {
1190+
def : Pat<(OpNode(v32i8 LASX256:$vj), (v32i8 LASX256:$vk)),
1191+
(!cast<LAInst>(Inst#"_H_B") LASX256:$vj, LASX256:$vk)>;
1192+
def : Pat<(OpNode(v16i16 LASX256:$vj), (v16i16 LASX256:$vk)),
1193+
(!cast<LAInst>(Inst#"_W_H") LASX256:$vj, LASX256:$vk)>;
1194+
def : Pat<(OpNode(v8i32 LASX256:$vj), (v8i32 LASX256:$vk)),
1195+
(!cast<LAInst>(Inst#"_D_W") LASX256:$vj, LASX256:$vk)>;
1196+
def : Pat<(OpNode(v4i64 LASX256:$vj), (v4i64 LASX256:$vk)),
1197+
(!cast<LAInst>(Inst#"_Q_D") LASX256:$vj, LASX256:$vk)>;
1198+
}
1199+
11891200
multiclass PatShiftXrXr<SDPatternOperator OpNode, string Inst> {
11901201
def : Pat<(OpNode (v32i8 LASX256:$xj), (and vsplati8_imm_eq_7,
11911202
(v32i8 LASX256:$xk))),
@@ -1513,6 +1524,9 @@ def : Pat<(bswap (v8i32 LASX256:$xj)), (XVSHUF4I_B LASX256:$xj, 0b00011011)>;
15131524
def : Pat<(bswap (v4i64 LASX256:$xj)),
15141525
(XVSHUF4I_W (XVSHUF4I_B LASX256:$xj, 0b00011011), 0b10110001)>;
15151526

1527+
// XVHADDW_{H_B/W_H/D_W/Q_D}
1528+
defm : PatXrXrW<loongarch_vhaddw, "XVHADDW">;
1529+
15161530
// XVFADD_{S/D}
15171531
defm : PatXrXrF<fadd, "XVFADD">;
15181532

llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
7171
def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
7272
def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
7373

74+
def loongarch_vhaddw : SDNode<"LoongArchISD::VHADDW", SDT_LoongArchV2R>;
75+
7476
def loongarch_vldrepl
7577
: SDNode<"LoongArchISD::VLDREPL",
7678
SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
@@ -1364,6 +1366,17 @@ multiclass PatVrVrVr<SDPatternOperator OpNode, string Inst> {
13641366
(!cast<LAInst>(Inst#"_D") LSX128:$vd, LSX128:$vj, LSX128:$vk)>;
13651367
}
13661368

1369+
multiclass PatVrVrW<SDPatternOperator OpNode, string Inst> {
1370+
def : Pat<(OpNode(v16i8 LSX128:$vj), (v16i8 LSX128:$vk)),
1371+
(!cast<LAInst>(Inst#"_H_B") LSX128:$vj, LSX128:$vk)>;
1372+
def : Pat<(OpNode(v8i16 LSX128:$vj), (v8i16 LSX128:$vk)),
1373+
(!cast<LAInst>(Inst#"_W_H") LSX128:$vj, LSX128:$vk)>;
1374+
def : Pat<(OpNode(v4i32 LSX128:$vj), (v4i32 LSX128:$vk)),
1375+
(!cast<LAInst>(Inst#"_D_W") LSX128:$vj, LSX128:$vk)>;
1376+
def : Pat<(OpNode(v2i64 LSX128:$vj), (v2i64 LSX128:$vk)),
1377+
(!cast<LAInst>(Inst#"_Q_D") LSX128:$vj, LSX128:$vk)>;
1378+
}
1379+
13671380
multiclass PatShiftVrVr<SDPatternOperator OpNode, string Inst> {
13681381
def : Pat<(OpNode (v16i8 LSX128:$vj), (and vsplati8_imm_eq_7,
13691382
(v16i8 LSX128:$vk))),
@@ -1709,6 +1722,9 @@ def : Pat<(bswap (v4i32 LSX128:$vj)), (VSHUF4I_B LSX128:$vj, 0b00011011)>;
17091722
def : Pat<(bswap (v2i64 LSX128:$vj)),
17101723
(VSHUF4I_W (VSHUF4I_B LSX128:$vj, 0b00011011), 0b10110001)>;
17111724

1725+
// VHADDW_{H_B/W_H/D_W/Q_D}
1726+
defm : PatVrVrW<loongarch_vhaddw, "VHADDW">;
1727+
17121728
// VFADD_{S/D}
17131729
defm : PatVrVrF<fadd, "VFADD">;
17141730

llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,13 @@ unsigned LoongArchTTIImpl::getPrefetchDistance() const { return 200; }
9595

9696
bool LoongArchTTIImpl::enableWritePrefetching() const { return true; }
9797

98+
bool LoongArchTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
99+
switch (II->getIntrinsicID()) {
100+
default:
101+
return true;
102+
case Intrinsic::vector_reduce_add:
103+
return false;
104+
}
105+
}
106+
98107
// TODO: Implement more hooks to provide TTI machinery for LoongArch.

llvm/lib/Target/LoongArch/LoongArchTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ class LoongArchTTIImpl : public BasicTTIImplBase<LoongArchTTIImpl> {
5353
unsigned getPrefetchDistance() const override;
5454
bool enableWritePrefetching() const override;
5555

56+
bool shouldExpandReduction(const IntrinsicInst *II) const override;
57+
5658
// TODO: Implement more hooks to provide TTI machinery for LoongArch.
5759
};
5860

Lines changed: 24 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,18 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2-
32
; RUN: llc --mtriple=loongarch64 --mattr=+lasx %s -o - | FileCheck %s
43

54
define void @vec_reduce_add_v32i8(ptr %src, ptr %dst) nounwind {
65
; CHECK-LABEL: vec_reduce_add_v32i8:
76
; CHECK: # %bb.0:
87
; CHECK-NEXT: xvld $xr0, $a0, 0
9-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 78
10-
; CHECK-NEXT: xvshuf4i.b $xr1, $xr1, 228
11-
; CHECK-NEXT: xvadd.b $xr0, $xr0, $xr1
12-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 68
13-
; CHECK-NEXT: xvbsrl.v $xr1, $xr1, 8
14-
; CHECK-NEXT: xvadd.b $xr0, $xr0, $xr1
15-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 68
16-
; CHECK-NEXT: xvsrli.d $xr1, $xr1, 32
17-
; CHECK-NEXT: xvadd.b $xr0, $xr0, $xr1
18-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 68
19-
; CHECK-NEXT: xvshuf4i.b $xr1, $xr1, 14
20-
; CHECK-NEXT: xvadd.b $xr0, $xr0, $xr1
21-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 68
22-
; CHECK-NEXT: xvrepl128vei.b $xr1, $xr1, 1
23-
; CHECK-NEXT: xvadd.b $xr0, $xr0, $xr1
24-
; CHECK-NEXT: xvstelm.b $xr0, $a1, 0, 0
8+
; CHECK-NEXT: xvhaddw.h.b $xr0, $xr0, $xr0
9+
; CHECK-NEXT: xvhaddw.w.h $xr0, $xr0, $xr0
10+
; CHECK-NEXT: xvhaddw.d.w $xr0, $xr0, $xr0
11+
; CHECK-NEXT: xvhaddw.q.d $xr0, $xr0, $xr0
12+
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 2
13+
; CHECK-NEXT: xvadd.d $xr0, $xr1, $xr0
14+
; CHECK-NEXT: xvpickve2gr.d $a0, $xr0, 0
15+
; CHECK-NEXT: st.b $a0, $a1, 0
2516
; CHECK-NEXT: ret
2617
%v = load <32 x i8>, ptr %src
2718
%res = call i8 @llvm.vector.reduce.add.v32i8(<32 x i8> %v)
@@ -33,19 +24,13 @@ define void @vec_reduce_add_v16i16(ptr %src, ptr %dst) nounwind {
3324
; CHECK-LABEL: vec_reduce_add_v16i16:
3425
; CHECK: # %bb.0:
3526
; CHECK-NEXT: xvld $xr0, $a0, 0
36-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 78
37-
; CHECK-NEXT: xvshuf4i.h $xr1, $xr1, 228
38-
; CHECK-NEXT: xvadd.h $xr0, $xr0, $xr1
39-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 68
40-
; CHECK-NEXT: xvbsrl.v $xr1, $xr1, 8
41-
; CHECK-NEXT: xvadd.h $xr0, $xr0, $xr1
42-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 68
43-
; CHECK-NEXT: xvshuf4i.h $xr1, $xr1, 14
44-
; CHECK-NEXT: xvadd.h $xr0, $xr0, $xr1
45-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 68
46-
; CHECK-NEXT: xvrepl128vei.h $xr1, $xr1, 1
47-
; CHECK-NEXT: xvadd.h $xr0, $xr0, $xr1
48-
; CHECK-NEXT: xvstelm.h $xr0, $a1, 0, 0
27+
; CHECK-NEXT: xvhaddw.w.h $xr0, $xr0, $xr0
28+
; CHECK-NEXT: xvhaddw.d.w $xr0, $xr0, $xr0
29+
; CHECK-NEXT: xvhaddw.q.d $xr0, $xr0, $xr0
30+
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 2
31+
; CHECK-NEXT: xvadd.d $xr0, $xr1, $xr0
32+
; CHECK-NEXT: xvpickve2gr.d $a0, $xr0, 0
33+
; CHECK-NEXT: st.h $a0, $a1, 0
4934
; CHECK-NEXT: ret
5035
%v = load <16 x i16>, ptr %src
5136
%res = call i16 @llvm.vector.reduce.add.v16i16(<16 x i16> %v)
@@ -57,16 +42,12 @@ define void @vec_reduce_add_v8i32(ptr %src, ptr %dst) nounwind {
5742
; CHECK-LABEL: vec_reduce_add_v8i32:
5843
; CHECK: # %bb.0:
5944
; CHECK-NEXT: xvld $xr0, $a0, 0
60-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 78
61-
; CHECK-NEXT: xvshuf4i.w $xr1, $xr1, 228
62-
; CHECK-NEXT: xvadd.w $xr0, $xr0, $xr1
63-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 68
64-
; CHECK-NEXT: xvshuf4i.w $xr1, $xr1, 14
65-
; CHECK-NEXT: xvadd.w $xr0, $xr0, $xr1
66-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 68
67-
; CHECK-NEXT: xvrepl128vei.w $xr1, $xr1, 1
68-
; CHECK-NEXT: xvadd.w $xr0, $xr0, $xr1
69-
; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0
45+
; CHECK-NEXT: xvhaddw.d.w $xr0, $xr0, $xr0
46+
; CHECK-NEXT: xvhaddw.q.d $xr0, $xr0, $xr0
47+
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 2
48+
; CHECK-NEXT: xvadd.d $xr0, $xr1, $xr0
49+
; CHECK-NEXT: xvpickve2gr.d $a0, $xr0, 0
50+
; CHECK-NEXT: st.w $a0, $a1, 0
7051
; CHECK-NEXT: ret
7152
%v = load <8 x i32>, ptr %src
7253
%res = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %v)
@@ -78,19 +59,13 @@ define void @vec_reduce_add_v4i64(ptr %src, ptr %dst) nounwind {
7859
; CHECK-LABEL: vec_reduce_add_v4i64:
7960
; CHECK: # %bb.0:
8061
; CHECK-NEXT: xvld $xr0, $a0, 0
81-
; CHECK-NEXT: pcalau12i $a0, %pc_hi20(.LCPI3_0)
82-
; CHECK-NEXT: xvld $xr1, $a0, %pc_lo12(.LCPI3_0)
83-
; CHECK-NEXT: xvpermi.d $xr2, $xr0, 78
84-
; CHECK-NEXT: xvshuf.d $xr1, $xr0, $xr2
85-
; CHECK-NEXT: xvadd.d $xr0, $xr0, $xr1
86-
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 68
87-
; CHECK-NEXT: xvrepl128vei.d $xr1, $xr1, 1
88-
; CHECK-NEXT: xvadd.d $xr0, $xr0, $xr1
62+
; CHECK-NEXT: xvhaddw.q.d $xr0, $xr0, $xr0
63+
; CHECK-NEXT: xvpermi.d $xr1, $xr0, 2
64+
; CHECK-NEXT: xvadd.d $xr0, $xr1, $xr0
8965
; CHECK-NEXT: xvstelm.d $xr0, $a1, 0, 0
9066
; CHECK-NEXT: ret
9167
%v = load <4 x i64>, ptr %src
9268
%res = call i64 @llvm.vector.reduce.add.v4i64(<4 x i64> %v)
9369
store i64 %res, ptr %dst
9470
ret void
9571
}
96-

0 commit comments

Comments
 (0)