Skip to content

Commit c507848

Browse files
authored
[LoongArch] Optimize extractelement containing variable index for lasx (#151475)
Ideas suggested by: @heiher @tangaac
1 parent dc170c7 commit c507848

File tree

4 files changed

+125
-88
lines changed

4 files changed

+125
-88
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,11 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
423423
setTargetDAGCombine(ISD::BITCAST);
424424
}
425425

426+
// Set DAG combine for 'LASX' feature.
427+
428+
if (Subtarget.hasExtLASX())
429+
setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
430+
426431
// Compute derived properties from the register classes.
427432
computeRegisterProperties(Subtarget.getRegisterInfo());
428433

@@ -2778,14 +2783,58 @@ SDValue LoongArchTargetLowering::lowerCONCAT_VECTORS(SDValue Op,
27782783
SDValue
27792784
LoongArchTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
27802785
SelectionDAG &DAG) const {
2781-
EVT VecTy = Op->getOperand(0)->getValueType(0);
2786+
MVT EltVT = Op.getSimpleValueType();
2787+
SDValue Vec = Op->getOperand(0);
2788+
EVT VecTy = Vec->getValueType(0);
27822789
SDValue Idx = Op->getOperand(1);
2783-
unsigned NumElts = VecTy.getVectorNumElements();
2790+
SDLoc DL(Op);
2791+
MVT GRLenVT = Subtarget.getGRLenVT();
2792+
2793+
assert(VecTy.is256BitVector() && "Unexpected EXTRACT_VECTOR_ELT vector type");
27842794

2785-
if (isa<ConstantSDNode>(Idx) && Idx->getAsZExtVal() < NumElts)
2795+
if (isa<ConstantSDNode>(Idx))
27862796
return Op;
27872797

2788-
return SDValue();
2798+
switch (VecTy.getSimpleVT().SimpleTy) {
2799+
default:
2800+
llvm_unreachable("Unexpected type");
2801+
case MVT::v32i8:
2802+
case MVT::v16i16:
2803+
case MVT::v4i64:
2804+
case MVT::v4f64: {
2805+
// Extract the high half subvector and place it in the low half of a new
2806+
// vector. It doesn't matter what the high half of the new vector is.
2807+
EVT HalfTy = VecTy.getHalfNumVectorElementsVT(*DAG.getContext());
2808+
SDValue VecHi =
2809+
DAG.getExtractSubvector(DL, HalfTy, Vec, HalfTy.getVectorNumElements());
2810+
SDValue TmpVec =
2811+
DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecTy, DAG.getUNDEF(VecTy),
2812+
VecHi, DAG.getConstant(0, DL, GRLenVT));
2813+
2814+
// Shuffle the original Vec and the TmpVec using MaskVec. The lowest element
2815+
// of MaskVec is Idx, the rest do not matter. ResVec[0] will hold the
2816+
// desired element.
2817+
SDValue IdxCp =
2818+
DAG.getNode(LoongArchISD::MOVGR2FR_W_LA64, DL, MVT::f32, Idx);
2819+
SDValue IdxVec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f32, IdxCp);
2820+
SDValue MaskVec =
2821+
DAG.getBitcast((VecTy == MVT::v4f64) ? MVT::v4i64 : VecTy, IdxVec);
2822+
SDValue ResVec =
2823+
DAG.getNode(LoongArchISD::VSHUF, DL, VecTy, MaskVec, TmpVec, Vec);
2824+
2825+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, ResVec,
2826+
DAG.getConstant(0, DL, GRLenVT));
2827+
}
2828+
case MVT::v8i32:
2829+
case MVT::v8f32: {
2830+
SDValue SplatIdx = DAG.getSplatBuildVector(MVT::v8i32, DL, Idx);
2831+
SDValue SplatValue =
2832+
DAG.getNode(LoongArchISD::XVPERM, DL, VecTy, Vec, SplatIdx);
2833+
2834+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, SplatValue,
2835+
DAG.getConstant(0, DL, GRLenVT));
2836+
}
2837+
}
27892838
}
27902839

27912840
SDValue
@@ -6152,6 +6201,42 @@ performSPLIT_PAIR_F64Combine(SDNode *N, SelectionDAG &DAG,
61526201
return SDValue();
61536202
}
61546203

6204+
static SDValue
6205+
performEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
6206+
TargetLowering::DAGCombinerInfo &DCI,
6207+
const LoongArchSubtarget &Subtarget) {
6208+
if (!DCI.isBeforeLegalize())
6209+
return SDValue();
6210+
6211+
MVT EltVT = N->getSimpleValueType(0);
6212+
SDValue Vec = N->getOperand(0);
6213+
EVT VecTy = Vec->getValueType(0);
6214+
SDValue Idx = N->getOperand(1);
6215+
unsigned IdxOp = Idx.getOpcode();
6216+
SDLoc DL(N);
6217+
6218+
if (!VecTy.is256BitVector() || isa<ConstantSDNode>(Idx))
6219+
return SDValue();
6220+
6221+
// Combine:
6222+
// t2 = truncate t1
6223+
// t3 = {zero/sign/any}_extend t2
6224+
// t4 = extract_vector_elt t0, t3
6225+
// to:
6226+
// t4 = extract_vector_elt t0, t1
6227+
if (IdxOp == ISD::ZERO_EXTEND || IdxOp == ISD::SIGN_EXTEND ||
6228+
IdxOp == ISD::ANY_EXTEND) {
6229+
SDValue IdxOrig = Idx.getOperand(0);
6230+
if (!(IdxOrig.getOpcode() == ISD::TRUNCATE))
6231+
return SDValue();
6232+
6233+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Vec,
6234+
IdxOrig.getOperand(0));
6235+
}
6236+
6237+
return SDValue();
6238+
}
6239+
61556240
SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
61566241
DAGCombinerInfo &DCI) const {
61576242
SelectionDAG &DAG = DCI.DAG;
@@ -6185,6 +6270,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
61856270
return performVMSKLTZCombine(N, DAG, DCI, Subtarget);
61866271
case LoongArchISD::SPLIT_PAIR_F64:
61876272
return performSPLIT_PAIR_F64Combine(N, DAG, DCI, Subtarget);
6273+
case ISD::EXTRACT_VECTOR_ELT:
6274+
return performEXTRACT_VECTOR_ELTCombine(N, DAG, DCI, Subtarget);
61886275
}
61896276
return SDValue();
61906277
}
@@ -6967,6 +7054,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
69677054
NODE_NAME_CASE(VREPLVEI)
69687055
NODE_NAME_CASE(VREPLGR2VR)
69697056
NODE_NAME_CASE(XVPERMI)
7057+
NODE_NAME_CASE(XVPERM)
69707058
NODE_NAME_CASE(VPICK_SEXT_ELT)
69717059
NODE_NAME_CASE(VPICK_ZEXT_ELT)
69727060
NODE_NAME_CASE(VREPLVE)

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ enum NodeType : unsigned {
145145
VREPLVEI,
146146
VREPLGR2VR,
147147
XVPERMI,
148+
XVPERM,
148149

149150
// Extended vector element extraction
150151
VPICK_SEXT_ELT,

llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,12 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
def SDT_LoongArchXVPERM : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisSameAs<0, 1>,
14+
SDTCisVec<2>, SDTCisInt<2>]>;
15+
1316
// Target nodes.
1417
def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>;
18+
def loongarch_xvperm: SDNode<"LoongArchISD::XVPERM", SDT_LoongArchXVPERM>;
1519
def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>;
1620
def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>;
1721
def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>;
@@ -1866,6 +1870,12 @@ def : Pat<(loongarch_xvpermi v4i64:$xj, immZExt8: $ui8),
18661870
def : Pat<(loongarch_xvpermi v4f64:$xj, immZExt8: $ui8),
18671871
(XVPERMI_D v4f64:$xj, immZExt8: $ui8)>;
18681872

1873+
// XVPERM_W
1874+
def : Pat<(loongarch_xvperm v8i32:$xj, v8i32:$xk),
1875+
(XVPERM_W v8i32:$xj, v8i32:$xk)>;
1876+
def : Pat<(loongarch_xvperm v8f32:$xj, v8i32:$xk),
1877+
(XVPERM_W v8f32:$xj, v8i32:$xk)>;
1878+
18691879
// XVREPLVE0_{W/D}
18701880
def : Pat<(lasxsplatf32 FPR32:$fj),
18711881
(XVREPLVE0_W (SUBREG_TO_REG (i64 0), FPR32:$fj, sub_32))>;

llvm/test/CodeGen/LoongArch/lasx/ir-instruction/extractelement.ll

Lines changed: 22 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -76,21 +76,11 @@ define void @extract_4xdouble(ptr %src, ptr %dst) nounwind {
7676
define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
7777
; CHECK-LABEL: extract_32xi8_idx:
7878
; CHECK: # %bb.0:
79-
; CHECK-NEXT: addi.d $sp, $sp, -96
80-
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
81-
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
82-
; CHECK-NEXT: addi.d $fp, $sp, 96
83-
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
8479
; CHECK-NEXT: xvld $xr0, $a0, 0
85-
; CHECK-NEXT: xvst $xr0, $sp, 32
86-
; CHECK-NEXT: addi.d $a0, $sp, 32
87-
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 0
88-
; CHECK-NEXT: ld.b $a0, $a0, 0
89-
; CHECK-NEXT: st.b $a0, $a1, 0
90-
; CHECK-NEXT: addi.d $sp, $fp, -96
91-
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
92-
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
93-
; CHECK-NEXT: addi.d $sp, $sp, 96
80+
; CHECK-NEXT: xvpermi.q $xr1, $xr0, 1
81+
; CHECK-NEXT: movgr2fr.w $fa2, $a2
82+
; CHECK-NEXT: xvshuf.b $xr0, $xr1, $xr0, $xr2
83+
; CHECK-NEXT: xvstelm.b $xr0, $a1, 0, 0
9484
; CHECK-NEXT: ret
9585
%v = load volatile <32 x i8>, ptr %src
9686
%e = extractelement <32 x i8> %v, i32 %idx
@@ -101,21 +91,11 @@ define void @extract_32xi8_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
10191
define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
10292
; CHECK-LABEL: extract_16xi16_idx:
10393
; CHECK: # %bb.0:
104-
; CHECK-NEXT: addi.d $sp, $sp, -96
105-
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
106-
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
107-
; CHECK-NEXT: addi.d $fp, $sp, 96
108-
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
10994
; CHECK-NEXT: xvld $xr0, $a0, 0
110-
; CHECK-NEXT: xvst $xr0, $sp, 32
111-
; CHECK-NEXT: addi.d $a0, $sp, 32
112-
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 1
113-
; CHECK-NEXT: ld.h $a0, $a0, 0
114-
; CHECK-NEXT: st.h $a0, $a1, 0
115-
; CHECK-NEXT: addi.d $sp, $fp, -96
116-
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
117-
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
118-
; CHECK-NEXT: addi.d $sp, $sp, 96
95+
; CHECK-NEXT: xvpermi.q $xr1, $xr0, 1
96+
; CHECK-NEXT: movgr2fr.w $fa2, $a2
97+
; CHECK-NEXT: xvshuf.h $xr2, $xr1, $xr0
98+
; CHECK-NEXT: xvstelm.h $xr2, $a1, 0, 0
11999
; CHECK-NEXT: ret
120100
%v = load volatile <16 x i16>, ptr %src
121101
%e = extractelement <16 x i16> %v, i32 %idx
@@ -126,21 +106,10 @@ define void @extract_16xi16_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
126106
define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
127107
; CHECK-LABEL: extract_8xi32_idx:
128108
; CHECK: # %bb.0:
129-
; CHECK-NEXT: addi.d $sp, $sp, -96
130-
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
131-
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
132-
; CHECK-NEXT: addi.d $fp, $sp, 96
133-
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
134109
; CHECK-NEXT: xvld $xr0, $a0, 0
135-
; CHECK-NEXT: xvst $xr0, $sp, 32
136-
; CHECK-NEXT: addi.d $a0, $sp, 32
137-
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 2
138-
; CHECK-NEXT: ld.w $a0, $a0, 0
139-
; CHECK-NEXT: st.w $a0, $a1, 0
140-
; CHECK-NEXT: addi.d $sp, $fp, -96
141-
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
142-
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
143-
; CHECK-NEXT: addi.d $sp, $sp, 96
110+
; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2
111+
; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1
112+
; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0
144113
; CHECK-NEXT: ret
145114
%v = load volatile <8 x i32>, ptr %src
146115
%e = extractelement <8 x i32> %v, i32 %idx
@@ -151,21 +120,11 @@ define void @extract_8xi32_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
151120
define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
152121
; CHECK-LABEL: extract_4xi64_idx:
153122
; CHECK: # %bb.0:
154-
; CHECK-NEXT: addi.d $sp, $sp, -96
155-
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
156-
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
157-
; CHECK-NEXT: addi.d $fp, $sp, 96
158-
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
159123
; CHECK-NEXT: xvld $xr0, $a0, 0
160-
; CHECK-NEXT: xvst $xr0, $sp, 32
161-
; CHECK-NEXT: addi.d $a0, $sp, 32
162-
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 3
163-
; CHECK-NEXT: ld.d $a0, $a0, 0
164-
; CHECK-NEXT: st.d $a0, $a1, 0
165-
; CHECK-NEXT: addi.d $sp, $fp, -96
166-
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
167-
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
168-
; CHECK-NEXT: addi.d $sp, $sp, 96
124+
; CHECK-NEXT: xvpermi.q $xr1, $xr0, 1
125+
; CHECK-NEXT: movgr2fr.w $fa2, $a2
126+
; CHECK-NEXT: xvshuf.d $xr2, $xr1, $xr0
127+
; CHECK-NEXT: xvstelm.d $xr2, $a1, 0, 0
169128
; CHECK-NEXT: ret
170129
%v = load volatile <4 x i64>, ptr %src
171130
%e = extractelement <4 x i64> %v, i32 %idx
@@ -176,21 +135,10 @@ define void @extract_4xi64_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
176135
define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
177136
; CHECK-LABEL: extract_8xfloat_idx:
178137
; CHECK: # %bb.0:
179-
; CHECK-NEXT: addi.d $sp, $sp, -96
180-
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
181-
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
182-
; CHECK-NEXT: addi.d $fp, $sp, 96
183-
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
184138
; CHECK-NEXT: xvld $xr0, $a0, 0
185-
; CHECK-NEXT: xvst $xr0, $sp, 32
186-
; CHECK-NEXT: addi.d $a0, $sp, 32
187-
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 2
188-
; CHECK-NEXT: fld.s $fa0, $a0, 0
189-
; CHECK-NEXT: fst.s $fa0, $a1, 0
190-
; CHECK-NEXT: addi.d $sp, $fp, -96
191-
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
192-
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
193-
; CHECK-NEXT: addi.d $sp, $sp, 96
139+
; CHECK-NEXT: xvreplgr2vr.w $xr1, $a2
140+
; CHECK-NEXT: xvperm.w $xr0, $xr0, $xr1
141+
; CHECK-NEXT: xvstelm.w $xr0, $a1, 0, 0
194142
; CHECK-NEXT: ret
195143
%v = load volatile <8 x float>, ptr %src
196144
%e = extractelement <8 x float> %v, i32 %idx
@@ -201,21 +149,11 @@ define void @extract_8xfloat_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
201149
define void @extract_4xdouble_idx(ptr %src, ptr %dst, i32 %idx) nounwind {
202150
; CHECK-LABEL: extract_4xdouble_idx:
203151
; CHECK: # %bb.0:
204-
; CHECK-NEXT: addi.d $sp, $sp, -96
205-
; CHECK-NEXT: st.d $ra, $sp, 88 # 8-byte Folded Spill
206-
; CHECK-NEXT: st.d $fp, $sp, 80 # 8-byte Folded Spill
207-
; CHECK-NEXT: addi.d $fp, $sp, 96
208-
; CHECK-NEXT: bstrins.d $sp, $zero, 4, 0
209152
; CHECK-NEXT: xvld $xr0, $a0, 0
210-
; CHECK-NEXT: xvst $xr0, $sp, 32
211-
; CHECK-NEXT: addi.d $a0, $sp, 32
212-
; CHECK-NEXT: bstrins.d $a0, $a2, 4, 3
213-
; CHECK-NEXT: fld.d $fa0, $a0, 0
214-
; CHECK-NEXT: fst.d $fa0, $a1, 0
215-
; CHECK-NEXT: addi.d $sp, $fp, -96
216-
; CHECK-NEXT: ld.d $fp, $sp, 80 # 8-byte Folded Reload
217-
; CHECK-NEXT: ld.d $ra, $sp, 88 # 8-byte Folded Reload
218-
; CHECK-NEXT: addi.d $sp, $sp, 96
153+
; CHECK-NEXT: xvpermi.q $xr1, $xr0, 1
154+
; CHECK-NEXT: movgr2fr.w $fa2, $a2
155+
; CHECK-NEXT: xvshuf.d $xr2, $xr1, $xr0
156+
; CHECK-NEXT: xvstelm.d $xr2, $a1, 0, 0
219157
; CHECK-NEXT: ret
220158
%v = load volatile <4 x double>, ptr %src
221159
%e = extractelement <4 x double> %v, i32 %idx

0 commit comments

Comments
 (0)