Skip to content

Commit d8b97d2

Browse files
committed
Lower build_vector to broadcast load if possible
1 parent d149631 commit d8b97d2

File tree

5 files changed

+80
-4
lines changed

5 files changed

+80
-4
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,6 +1876,47 @@ static bool isConstantOrUndefBUILD_VECTOR(const BuildVectorSDNode *Op) {
18761876
return false;
18771877
}
18781878

1879+
// Lower BUILD_VECTOR as broadcast load (if possible).
1880+
// For example:
1881+
// %a = load i8, ptr %ptr
1882+
// %b = build_vector %a, %a, %a, %a
1883+
// is lowered to :
1884+
// (VLDREPL_B $a0, 0)
1885+
static SDValue lowerBUILD_VECTORAsBroadCastLoad(BuildVectorSDNode *BVOp,
1886+
const SDLoc &DL,
1887+
SelectionDAG &DAG) {
1888+
MVT VT = BVOp->getSimpleValueType(0);
1889+
int NumOps = BVOp->getNumOperands();
1890+
1891+
assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
1892+
"Unsupported vector type for broadcast.");
1893+
1894+
SDValue IdentitySrc;
1895+
bool IsIdeneity = true;
1896+
1897+
for (int i = 0; i != NumOps; i++) {
1898+
SDValue Op = BVOp->getOperand(i);
1899+
if (Op.getOpcode() != ISD::LOAD || (IdentitySrc && Op != IdentitySrc)) {
1900+
IsIdeneity = false;
1901+
break;
1902+
}
1903+
IdentitySrc = BVOp->getOperand(0);
1904+
}
1905+
1906+
if (IsIdeneity) {
1907+
auto *LN = cast<LoadSDNode>(IdentitySrc);
1908+
SDVTList Tys =
1909+
LN->isIndexed()
1910+
? DAG.getVTList(VT, LN->getBasePtr().getValueType(), MVT::Other)
1911+
: DAG.getVTList(VT, MVT::Other);
1912+
SDValue Ops[] = {LN->getChain(), LN->getBasePtr(), LN->getOffset()};
1913+
SDValue BCast = DAG.getNode(LoongArchISD::VLDREPL, DL, Tys, Ops);
1914+
DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BCast.getValue(1));
1915+
return BCast;
1916+
}
1917+
return SDValue();
1918+
}
1919+
18791920
SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
18801921
SelectionDAG &DAG) const {
18811922
BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
@@ -1891,6 +1932,9 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
18911932
(!Subtarget.hasExtLASX() || !Is256Vec))
18921933
return SDValue();
18931934

1935+
if (SDValue Result = lowerBUILD_VECTORAsBroadCastLoad(Node, DL, DAG))
1936+
return Result;
1937+
18941938
if (Node->isConstantSplat(SplatValue, SplatUndef, SplatBitSize, HasAnyUndefs,
18951939
/*MinSplatBits=*/8) &&
18961940
SplatBitSize <= 64) {
@@ -5326,6 +5370,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
53265370
NODE_NAME_CASE(VSRLI)
53275371
NODE_NAME_CASE(VBSLL)
53285372
NODE_NAME_CASE(VBSRL)
5373+
NODE_NAME_CASE(VLDREPL)
53295374
}
53305375
#undef NODE_NAME_CASE
53315376
return nullptr;

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@ enum NodeType : unsigned {
155155

156156
// Vector byte logicial left / right shift
157157
VBSLL,
158-
VBSRL
158+
VBSRL,
159+
160+
// Scalar load broadcast to vector
161+
VLDREPL
159162

160163
// Intrinsic operations end =============================================
161164
};

llvm/lib/Target/LoongArch/LoongArchInstrInfo.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,8 @@ def simm8_lsl # I : Operand<GRLenVT> {
307307
}
308308
}
309309

310-
def simm9_lsl3 : Operand<GRLenVT> {
310+
def simm9_lsl3 : Operand<GRLenVT>,
311+
ImmLeaf<GRLenVT, [{return isShiftedInt<9,3>(Imm);}]> {
311312
let ParserMatchClass = SImmAsmOperand<9, "lsl3">;
312313
let EncoderMethod = "getImmOpValueAsr<3>";
313314
let DecoderMethod = "decodeSImmOperand<9, 3>";
@@ -317,13 +318,15 @@ def simm10 : Operand<GRLenVT> {
317318
let ParserMatchClass = SImmAsmOperand<10>;
318319
}
319320

320-
def simm10_lsl2 : Operand<GRLenVT> {
321+
def simm10_lsl2 : Operand<GRLenVT>,
322+
ImmLeaf<GRLenVT, [{return isShiftedInt<10,2>(Imm);}]> {
321323
let ParserMatchClass = SImmAsmOperand<10, "lsl2">;
322324
let EncoderMethod = "getImmOpValueAsr<2>";
323325
let DecoderMethod = "decodeSImmOperand<10, 2>";
324326
}
325327

326-
def simm11_lsl1 : Operand<GRLenVT> {
328+
def simm11_lsl1 : Operand<GRLenVT>,
329+
ImmLeaf<GRLenVT, [{return isShiftedInt<11,1>(Imm);}]> {
327330
let ParserMatchClass = SImmAsmOperand<11, "lsl1">;
328331
let EncoderMethod = "getImmOpValueAsr<1>";
329332
let DecoderMethod = "decodeSImmOperand<11, 1>";

llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2165,6 +2165,7 @@ def : Pat<(int_loongarch_lasx_xvld GPR:$rj, timm:$imm),
21652165
def : Pat<(int_loongarch_lasx_xvldx GPR:$rj, GPR:$rk),
21662166
(XVLDX GPR:$rj, GPR:$rk)>;
21672167

2168+
// xvldrepl
21682169
def : Pat<(int_loongarch_lasx_xvldrepl_b GPR:$rj, timm:$imm),
21692170
(XVLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
21702171
def : Pat<(int_loongarch_lasx_xvldrepl_h GPR:$rj, timm:$imm),
@@ -2174,6 +2175,11 @@ def : Pat<(int_loongarch_lasx_xvldrepl_w GPR:$rj, timm:$imm),
21742175
def : Pat<(int_loongarch_lasx_xvldrepl_d GPR:$rj, timm:$imm),
21752176
(XVLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;
21762177

2178+
defm : VldreplPat<v32i8, XVLDREPL_B, simm12_addlike>;
2179+
defm : VldreplPat<v16i16, XVLDREPL_H, simm11_lsl1>;
2180+
defm : VldreplPat<v8i32, XVLDREPL_W, simm10_lsl2>;
2181+
defm : VldreplPat<v4i64, XVLDREPL_D, simm9_lsl3>;
2182+
21772183
// store
21782184
def : Pat<(int_loongarch_lasx_xvst LASX256:$xd, GPR:$rj, timm:$imm),
21792185
(XVST LASX256:$xd, GPR:$rj, (to_valid_timm timm:$imm))>;

llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def SDT_LoongArchV1RUimm: SDTypeProfile<1, 2, [SDTCisVec<0>,
2626
def SDT_LoongArchVreplgr2vr : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<0>, SDTCisInt<1>]>;
2727
def SDT_LoongArchVFRECIPE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
2828
def SDT_LoongArchVFRSQRTE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
29+
def SDT_LoongArchVLDREPL : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisPtrTy<1>]>;
2930

3031
// Target nodes.
3132
def loongarch_vreplve : SDNode<"LoongArchISD::VREPLVE", SDT_LoongArchVreplve>;
@@ -64,6 +65,10 @@ def loongarch_vsrli : SDNode<"LoongArchISD::VSRLI", SDT_LoongArchV1RUimm>;
6465
def loongarch_vbsll : SDNode<"LoongArchISD::VBSLL", SDT_LoongArchV1RUimm>;
6566
def loongarch_vbsrl : SDNode<"LoongArchISD::VBSRL", SDT_LoongArchV1RUimm>;
6667

68+
def loongarch_vldrepl
69+
: SDNode<"LoongArchISD::VLDREPL",
70+
SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
71+
6772
def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
6873
def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
6974
def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
@@ -1433,6 +1438,14 @@ multiclass PatCCVrVrF<CondCode CC, string Inst> {
14331438
(!cast<LAInst>(Inst#"_D") LSX128:$vj, LSX128:$vk)>;
14341439
}
14351440

1441+
multiclass VldreplPat<ValueType vt, LAInst Inst, Operand ImmOpnd> {
1442+
def : Pat<(vt(loongarch_vldrepl BaseAddr:$rj)), (Inst BaseAddr:$rj, 0)>;
1443+
def : Pat<(vt(loongarch_vldrepl(AddrConstant GPR:$rj, ImmOpnd:$imm))),
1444+
(Inst GPR:$rj, ImmOpnd:$imm)>;
1445+
def : Pat<(vt(loongarch_vldrepl(AddLike BaseAddr:$rj, ImmOpnd:$imm))),
1446+
(Inst BaseAddr:$rj, ImmOpnd:$imm)>;
1447+
}
1448+
14361449
let Predicates = [HasExtLSX] in {
14371450

14381451
// VADD_{B/H/W/D}
@@ -2342,6 +2355,7 @@ def : Pat<(int_loongarch_lsx_vld GPR:$rj, timm:$imm),
23422355
def : Pat<(int_loongarch_lsx_vldx GPR:$rj, GPR:$rk),
23432356
(VLDX GPR:$rj, GPR:$rk)>;
23442357

2358+
// vldrepl
23452359
def : Pat<(int_loongarch_lsx_vldrepl_b GPR:$rj, timm:$imm),
23462360
(VLDREPL_B GPR:$rj, (to_valid_timm timm:$imm))>;
23472361
def : Pat<(int_loongarch_lsx_vldrepl_h GPR:$rj, timm:$imm),
@@ -2351,6 +2365,11 @@ def : Pat<(int_loongarch_lsx_vldrepl_w GPR:$rj, timm:$imm),
23512365
def : Pat<(int_loongarch_lsx_vldrepl_d GPR:$rj, timm:$imm),
23522366
(VLDREPL_D GPR:$rj, (to_valid_timm timm:$imm))>;
23532367

2368+
defm : VldreplPat<v16i8, VLDREPL_B, simm12_addlike>;
2369+
defm : VldreplPat<v8i16, VLDREPL_H, simm11_lsl1>;
2370+
defm : VldreplPat<v4i32, VLDREPL_W, simm10_lsl2>;
2371+
defm : VldreplPat<v2i64, VLDREPL_D, simm9_lsl3>;
2372+
23542373
// store
23552374
def : Pat<(int_loongarch_lsx_vst LSX128:$vd, GPR:$rj, timm:$imm),
23562375
(VST LSX128:$vd, GPR:$rj, (to_valid_timm timm:$imm))>;

0 commit comments

Comments
 (0)