Skip to content

Commit 3914b96

Browse files
4vtomatBeMg
andcommitted
[RISCV] Support XSfmm LLVM IR and CodeGen
Co-authored-by: Piyou Chen <[email protected]>
1 parent fd3c5d4 commit 3914b96

30 files changed

+1141
-0
lines changed

llvm/include/llvm/IR/IntrinsicsRISCVXsf.td

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,4 +180,98 @@ let TargetPrefix = "riscv" in {
180180
// XSfvfnrclipxfqf
181181
defm int_riscv_sf_vfnrclip_x_f_qf : RISCVSFCustomVFNRCLIP;
182182
defm int_riscv_sf_vfnrclip_xu_f_qf : RISCVSFCustomVFNRCLIP;
183+
184+
// XSfmm
185+
// Output: (output_len)
186+
// Input: (input_len, vsew, twiden)
187+
// Common signature of the XSfmm vset intrinsics. A single overloaded integer
// type is used for the result and for all three operands; operands 1 and 2
// must be immediates, and the intrinsic does not access memory.
class RISCVSFVSet
    : DefaultAttrsIntrinsic<
          [llvm_anyint_ty],
          [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>],
          [ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>, IntrNoMem]>;
191+
192+
// Input: (tss, base, tn)
193+
// Common signature of the XSfmm tile-load intrinsics. No result; operands 0
// and 2 share one overloaded integer type and operand 1 is a pointer that is
// marked nocapture.
class RISCVSFTileLoad
    : DefaultAttrsIntrinsic<
          [],
          [llvm_anyint_ty, llvm_ptr_ty, LLVMMatchType<0>],
          [NoCapture<ArgIndex<1>>, IntrHasSideEffects]>,
      RISCVVIntrinsic;
198+
199+
// Input: (tss, base, tn)
200+
// Common signature of the XSfmm tile-store intrinsics. Same operand list as
// the tile loads (integer, nocapture pointer, integer), additionally marked
// IntrWriteMem.
class RISCVSFTileStore
    : DefaultAttrsIntrinsic<
          [],
          [llvm_anyint_ty, llvm_ptr_ty, LLVMMatchType<0>],
          [NoCapture<ArgIndex<1>>, IntrWriteMem, IntrHasSideEffects]>,
      RISCVVIntrinsic;
206+
207+
// Output: ()
208+
// Input: (mtd, mat1, mat2, tm, tn, tk, twiden)
209+
// Common signature of the XSfmm matrix-multiply intrinsics.
//
// When is_float is set, the second matrix operand is constrained to the same
// vector type as the first; otherwise it is overloaded independently. The
// first operand (index 0) and the last operand (index 6) must be immediates.
class RISCVSFCustomMatMul<bit is_float = false>
    : DefaultAttrsIntrinsic<
          [],
          [llvm_anyint_ty, llvm_anyvector_ty,
           !if(is_float, LLVMMatchType<1>, llvm_anyvector_ty),
           LLVMMatchType<0>, LLVMMatchType<0>,
           LLVMMatchType<0>, LLVMMatchType<0>],
          [IntrNoMem, IntrHasSideEffects,
           ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<6>>]>,
      RISCVVIntrinsic;
218+
219+
// Tile-configuration intrinsics.
def int_riscv_sf_vsettnt : RISCVSFVSet;
def int_riscv_sf_vsettm : RISCVSFVSet;
def int_riscv_sf_vsettk : RISCVSFVSet;

// Tile loads (int_riscv_sf_vlte8/16/32/64) followed by tile stores
// (int_riscv_sf_vste8/16/32/64), one per element width.
foreach eew = [8, 16, 32, 64] in
  def int_riscv_sf_vlte # eew : RISCVSFTileLoad;
foreach eew = [8, 16, 32, 64] in
  def int_riscv_sf_vste # eew : RISCVSFTileStore;
231+
232+
// Output: (vd)
233+
// Input: (tss, tn)
234+
// The result is an overloaded vector; both inputs (tss, tn) share one
// overloaded integer type.
//
// VLOperand indexes the input list from zero (cf. int_riscv_sf_vtmv_t_v, whose
// tn is input 2 of three). This intrinsic has only two inputs, so tn is input
// 1; the previous value of 2 pointed past the end of the operand list.
// NOTE(review): confirm no consumer relied on the old out-of-range index.
def int_riscv_sf_vtmv_v_t
    : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
                            [llvm_anyint_ty, LLVMMatchType<1>],
                            [IntrNoMem, IntrHasSideEffects]>,
      RISCVVIntrinsic {
  let VLOperand = 1;
}
241+
// Output: ()
242+
// Input: (tss, vs2, tn)
243+
// No result. The first input reuses the overloaded integer type introduced by
// the third input, and the second input is an overloaded vector. VLOperand
// selects the third input (index 2).
def int_riscv_sf_vtmv_t_v
    : DefaultAttrsIntrinsic<
          [],
          [LLVMMatchType<1>, llvm_anyvector_ty, llvm_anyint_ty],
          [IntrNoMem, IntrHasSideEffects]>,
      RISCVVIntrinsic {
  let VLOperand = 2;
}
249+
250+
// Integer matrix multiplies: int_riscv_sf_mm_{u,s}_{u,s}, with the two matrix
// operands overloaded independently (is_float left at its default).
foreach a = ["u", "s"] in
  foreach b = ["u", "s"] in
    def int_riscv_sf_mm_ # a # _ # b : RISCVSFCustomMatMul;
255+
256+
// Floating-point matrix multiplies; both matrix operands share one vector
// type (is_float = true).
def int_riscv_sf_mm_f_f : RISCVSFCustomMatMul<true>;

// Expands to int_riscv_sf_mm_e5m2_e5m2, _e5m2_e4m3, _e4m3_e5m2 and
// _e4m3_e4m3: the digit after "m" is 7 minus the digit after "e".
foreach e1 = [5, 4] in {
  foreach e2 = [5, 4] in {
    def int_riscv_sf_mm_e # e1 # m # !sub(7, e1) # _e # e2 # m # !sub(7, e2)
        : RISCVSFCustomMatMul<true>;
  }
}
261+
262+
// Output: ()
263+
// Input: (mtd, tm, tn, log2sew, twiden)
264+
// No result; five operands of one overloaded integer type. The first operand
// (index 0) and the last two (indices 3 and 4) must be immediates.
def int_riscv_sf_vtzero_t
    : DefaultAttrsIntrinsic<
          [],
          [llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>,
           LLVMMatchType<0>, LLVMMatchType<0>],
          [ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>,
           IntrNoMem, IntrHasSideEffects]>,
      RISCVVIntrinsic;
271+
272+
// Output: ()
273+
// Input: ()
274+
// Takes no operands and produces no result; declared IntrNoMem but with
// IntrHasSideEffects.
def int_riscv_sf_vtdiscard
    : DefaultAttrsIntrinsic<[], [], [IntrNoMem, IntrHasSideEffects]>,
      RISCVVIntrinsic;
183277
} // TargetPrefix = "riscv"

llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,12 @@ static bool lowerRISCVVMachineInstrToMCInst(const MachineInstr *MI,
11001100
--NumOps;
11011101
if (RISCVII::hasRoundModeOp(TSFlags))
11021102
--NumOps;
1103+
if (RISCVII::hasTWidenOp(TSFlags))
1104+
--NumOps;
1105+
if (RISCVII::hasTMOp(TSFlags))
1106+
--NumOps;
1107+
if (RISCVII::hasTKOp(TSFlags))
1108+
--NumOps;
11031109

11041110
bool hasVLOutput = RISCVInstrInfo::isFaultOnlyFirstLoad(*MI);
11051111
for (unsigned OpNo = 0; OpNo != NumOps; ++OpNo) {

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,44 @@ void RISCVDAGToDAGISel::selectVSETVLI(SDNode *Node) {
516516
CurDAG->getMachineNode(Opcode, DL, XLenVT, VLOperand, VTypeIOp));
517517
}
518518

519+
/// Select the sf.vsettnt / sf.vsettm / sf.vsettk intrinsics to their pseudo
/// instructions.
///
/// Operand 0 of \p Node is the intrinsic ID, operand 1 is the input length,
/// and operands 2 and 3 are the encoded SEW and TWiden immediates. Nothing is
/// selected when the subtarget lacks XSfmmbase.
void RISCVDAGToDAGISel::selectXSfmmVSET(SDNode *Node) {
  if (!Subtarget->hasVendorXSfmmbase())
    return;

  assert(Node->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Unexpected opcode");

  SDLoc DL(Node);
  MVT XLenVT = Subtarget->getXLenVT();

  const unsigned IntNo = Node->getConstantOperandVal(0);
  const bool IsVSetTNT = IntNo == Intrinsic::riscv_sf_vsettnt;
  const bool IsVSetTM = IntNo == Intrinsic::riscv_sf_vsettm;
  assert((IsVSetTNT || IsVSetTM || IntNo == Intrinsic::riscv_sf_vsettk) &&
         "Unexpected XSfmm vset intrinsic");

  const unsigned SEW = RISCVVType::decodeVSEW(Node->getConstantOperandVal(2));
  const unsigned Widen =
      RISCVVType::decodeTWiden(Node->getConstantOperandVal(3));
  SDValue Len = Node->getOperand(1);

  MachineSDNode *VSet;
  if (IsVSetTNT) {
    // sf.vsettnt receives SEW and TWiden folded into one vtype immediate.
    SDValue VTypeIOp = CurDAG->getTargetConstant(
        RISCVVType::encodeXSfmmVType(SEW, Widen, 0), DL, XLenVT);
    VSet = CurDAG->getMachineNode(RISCV::PseudoSF_VSETTNT, DL, XLenVT, Len,
                                  VTypeIOp);
  } else {
    // sf.vsettm / sf.vsettk receive log2(SEW) and TWiden as separate
    // immediates.
    const unsigned PseudoOpCode =
        IsVSetTM ? RISCV::PseudoSF_VSETTM : RISCV::PseudoSF_VSETTK;
    SDValue Log2SEW = CurDAG->getTargetConstant(Log2_32(SEW), DL, XLenVT);
    SDValue TWiden = CurDAG->getTargetConstant(Widen, DL, XLenVT);
    VSet = CurDAG->getMachineNode(PseudoOpCode, DL, XLenVT, Len, Log2SEW,
                                  TWiden);
  }

  ReplaceNode(Node, VSet);
}
556+
519557
bool RISCVDAGToDAGISel::tryShrinkShlLogicImm(SDNode *Node) {
520558
MVT VT = Node->getSimpleValueType(0);
521559
unsigned Opcode = Node->getOpcode();
@@ -936,6 +974,11 @@ bool RISCVDAGToDAGISel::tryIndexedLoad(SDNode *Node) {
936974
return true;
937975
}
938976

977+
/// Map an XSfmm tile index to the corresponding tile register, counted from
/// RISCV::T0.
///
/// Callers pass the value of an intrinsic immediate operand, so an
/// out-of-range index is a usage error rather than an internal invariant
/// violation. It is reported in every build mode; an assert alone would be
/// compiled out of release builds and yield an unrelated register number.
static Register getTileReg(uint64_t TileNum) {
  if (TileNum > 15)
    reportFatalUsageError("Invalid tile number: must be in the range [0, 15]");
  return RISCV::T0 + TileNum;
}
981+
939982
void RISCVDAGToDAGISel::selectSF_VC_X_SE(SDNode *Node) {
940983
if (!Subtarget->hasVInstructions())
941984
return;
@@ -2130,6 +2173,10 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
21302173
case Intrinsic::riscv_vsetvli:
21312174
case Intrinsic::riscv_vsetvlimax:
21322175
return selectVSETVLI(Node);
2176+
case Intrinsic::riscv_sf_vsettnt:
2177+
case Intrinsic::riscv_sf_vsettm:
2178+
case Intrinsic::riscv_sf_vsettk:
2179+
return selectXSfmmVSET(Node);
21332180
}
21342181
break;
21352182
}
@@ -2553,6 +2600,142 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
25532600
case Intrinsic::riscv_sf_vc_i_se:
25542601
selectSF_VC_X_SE(Node);
25552602
return;
2603+
case Intrinsic::riscv_sf_vlte8:
2604+
case Intrinsic::riscv_sf_vlte16:
2605+
case Intrinsic::riscv_sf_vlte32:
2606+
case Intrinsic::riscv_sf_vlte64: {
2607+
unsigned Log2SEW;
2608+
unsigned PseudoInst;
2609+
switch (IntNo) {
2610+
case Intrinsic::riscv_sf_vlte8:
2611+
PseudoInst = RISCV::PseudoSF_VLTE8;
2612+
Log2SEW = 3;
2613+
break;
2614+
case Intrinsic::riscv_sf_vlte16:
2615+
PseudoInst = RISCV::PseudoSF_VLTE16;
2616+
Log2SEW = 4;
2617+
break;
2618+
case Intrinsic::riscv_sf_vlte32:
2619+
PseudoInst = RISCV::PseudoSF_VLTE32;
2620+
Log2SEW = 5;
2621+
break;
2622+
case Intrinsic::riscv_sf_vlte64:
2623+
PseudoInst = RISCV::PseudoSF_VLTE64;
2624+
Log2SEW = 6;
2625+
break;
2626+
}
2627+
2628+
SDValue SEWOp = CurDAG->getTargetConstant(Log2SEW, DL, XLenVT);
2629+
SDValue TWidenOp = CurDAG->getTargetConstant(1, DL, XLenVT);
2630+
SDValue Operands[] = {Node->getOperand(2),
2631+
Node->getOperand(3),
2632+
Node->getOperand(4),
2633+
SEWOp,
2634+
TWidenOp,
2635+
Node->getOperand(0)};
2636+
2637+
MachineSDNode *TileLoad =
2638+
CurDAG->getMachineNode(PseudoInst, DL, Node->getVTList(), Operands);
2639+
if (auto *MemOp = dyn_cast<MemSDNode>(Node))
2640+
CurDAG->setNodeMemRefs(TileLoad, {MemOp->getMemOperand()});
2641+
2642+
ReplaceNode(Node, TileLoad);
2643+
return;
2644+
}
2645+
case Intrinsic::riscv_sf_mm_s_s:
2646+
case Intrinsic::riscv_sf_mm_s_u:
2647+
case Intrinsic::riscv_sf_mm_u_s:
2648+
case Intrinsic::riscv_sf_mm_u_u:
2649+
case Intrinsic::riscv_sf_mm_e5m2_e5m2:
2650+
case Intrinsic::riscv_sf_mm_e5m2_e4m3:
2651+
case Intrinsic::riscv_sf_mm_e4m3_e5m2:
2652+
case Intrinsic::riscv_sf_mm_e4m3_e4m3:
2653+
case Intrinsic::riscv_sf_mm_f_f: {
2654+
bool HasFRM = false;
2655+
unsigned PseudoInst;
2656+
switch (IntNo) {
2657+
case Intrinsic::riscv_sf_mm_s_s:
2658+
PseudoInst = RISCV::PseudoSF_MM_S_S;
2659+
break;
2660+
case Intrinsic::riscv_sf_mm_s_u:
2661+
PseudoInst = RISCV::PseudoSF_MM_S_U;
2662+
break;
2663+
case Intrinsic::riscv_sf_mm_u_s:
2664+
PseudoInst = RISCV::PseudoSF_MM_U_S;
2665+
break;
2666+
case Intrinsic::riscv_sf_mm_u_u:
2667+
PseudoInst = RISCV::PseudoSF_MM_U_U;
2668+
break;
2669+
case Intrinsic::riscv_sf_mm_e5m2_e5m2:
2670+
PseudoInst = RISCV::PseudoSF_MM_E5M2_E5M2;
2671+
HasFRM = true;
2672+
break;
2673+
case Intrinsic::riscv_sf_mm_e5m2_e4m3:
2674+
PseudoInst = RISCV::PseudoSF_MM_E5M2_E4M3;
2675+
HasFRM = true;
2676+
break;
2677+
case Intrinsic::riscv_sf_mm_e4m3_e5m2:
2678+
PseudoInst = RISCV::PseudoSF_MM_E4M3_E5M2;
2679+
HasFRM = true;
2680+
break;
2681+
case Intrinsic::riscv_sf_mm_e4m3_e4m3:
2682+
PseudoInst = RISCV::PseudoSF_MM_E4M3_E4M3;
2683+
HasFRM = true;
2684+
break;
2685+
case Intrinsic::riscv_sf_mm_f_f:
2686+
if (Node->getOperand(3).getValueType().getScalarType() == MVT::bf16)
2687+
PseudoInst = RISCV::PseudoSF_MM_F_F_ALT;
2688+
else
2689+
PseudoInst = RISCV::PseudoSF_MM_F_F;
2690+
HasFRM = true;
2691+
break;
2692+
}
2693+
uint64_t TileNum = Node->getConstantOperandVal(2);
2694+
SDValue Op1 = Node->getOperand(3);
2695+
SDValue Op2 = Node->getOperand(4);
2696+
MVT VT = Op1->getSimpleValueType(0);
2697+
unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits());
2698+
SDValue TmOp = Node->getOperand(5);
2699+
SDValue TnOp = Node->getOperand(6);
2700+
SDValue TkOp = Node->getOperand(7);
2701+
SDValue TWidenOp = Node->getOperand(8);
2702+
SDValue Chain = Node->getOperand(0);
2703+
2704+
// sf.mm.f.f with sew=32, twiden=2 is invalid
2705+
if (IntNo == Intrinsic::riscv_sf_mm_f_f && Log2SEW == 5 &&
2706+
TWidenOp->getAsZExtVal() == 2)
2707+
reportFatalUsageError("sf.mm.f.f doesn't support (sew=32, twiden=2)");
2708+
2709+
SmallVector<SDValue, 10> Operands(
2710+
{CurDAG->getRegister(getTileReg(TileNum), XLenVT), Op1, Op2});
2711+
if (HasFRM)
2712+
Operands.push_back(
2713+
CurDAG->getTargetConstant(RISCVFPRndMode::DYN, DL, XLenVT));
2714+
Operands.append({TmOp, TnOp, TkOp,
2715+
CurDAG->getTargetConstant(Log2SEW, DL, XLenVT), TWidenOp,
2716+
Chain});
2717+
2718+
auto *NewNode =
2719+
CurDAG->getMachineNode(PseudoInst, DL, Node->getVTList(), Operands);
2720+
2721+
ReplaceNode(Node, NewNode);
2722+
return;
2723+
}
2724+
case Intrinsic::riscv_sf_vtzero_t: {
2725+
uint64_t TileNum = Node->getConstantOperandVal(2);
2726+
SDValue Tm = Node->getOperand(3);
2727+
SDValue Tn = Node->getOperand(4);
2728+
SDValue Log2SEW = Node->getOperand(5);
2729+
SDValue TWiden = Node->getOperand(6);
2730+
SDValue Chain = Node->getOperand(0);
2731+
auto *NewNode = CurDAG->getMachineNode(
2732+
RISCV::PseudoSF_VTZERO_T, DL, Node->getVTList(),
2733+
{CurDAG->getRegister(getTileReg(TileNum), XLenVT), Tm, Tn, Log2SEW,
2734+
TWiden, Chain});
2735+
2736+
ReplaceNode(Node, NewNode);
2737+
return;
2738+
}
25562739
}
25572740
break;
25582741
}

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
167167
void selectVSXSEG(SDNode *Node, unsigned NF, bool IsMasked, bool IsOrdered);
168168

169169
void selectVSETVLI(SDNode *Node);
170+
void selectXSfmmVSET(SDNode *Node);
170171

171172
void selectSF_VC_X_SE(SDNode *Node);
172173

llvm/lib/Target/RISCV/RISCVInstrInfoXSfmm.td

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,3 +415,54 @@ let hasSideEffects = 1, mayLoad = 0, mayStore = 0 in {
415415
(ins TR:$rd, AVL:$atm, AVL:$atn, ixlenimm:$sew, ixlenimm:$twiden)>;
416416
def PseudoSF_VTDISCARD : RISCVVPseudo<(outs), (ins), []>;
417417
}
418+
419+
// Selects the tile-store intrinsic named by intrinsic_name, called with
// (rs2, rs1, tn), to the instruction named by inst_name. The instruction gets
// the same three operands followed by the constant log2sew and a literal 1
// (presumably TWiden = 1 -- confirm against the pseudo's operand list).
class VPatXSfmmTileStore<string intrinsic_name, string inst_name, int log2sew>
    : Pat<(!cast<Intrinsic>(intrinsic_name) (XLenVT GPR:$rs2),
                                            (XLenVT GPR:$rs1),
                                            (XLenVT AVL:$tn)),
          (!cast<Instruction>(inst_name) (XLenVT GPR:$rs2),
                                         (XLenVT GPR:$rs1),
                                         GPR:$tn, log2sew, 1)>;
430+
431+
// Selects the intrinsic named by intrinsic_name, called with (rs1, vs2, atn)
// where vs2 is a reg_type value in VRM8, to the instruction named by
// inst_name. The instruction gets the same operands followed by the constant
// log2sew and a literal 1 (presumably TWiden = 1 -- confirm against the
// pseudo's operand list).
class VPatXSfmmTileMove_T_V<string intrinsic_name, string inst_name,
                            ValueType reg_type, int log2sew>
    : Pat<(!cast<Intrinsic>(intrinsic_name) (XLenVT GPR:$rs1),
                                            (reg_type VRM8:$vs2),
                                            (XLenVT AVL:$atn)),
          (!cast<Instruction>(inst_name) (XLenVT GPR:$rs1),
                                         (reg_type VRM8:$vs2),
                                         GPR:$atn, log2sew, 1)>;
443+
444+
// Selects the intrinsic named by intrinsic_name, called with (rs1, atn) and
// producing a result_type value, to the instruction named by inst_name. The
// instruction gets the same operands followed by the constant log2sew and a
// literal 1 (presumably TWiden = 1 -- confirm against the pseudo's operand
// list).
class VPatXSfmmTileMove_V_T<string intrinsic_name, string inst_name,
                            ValueType result_type, int log2sew>
    : Pat<(result_type (!cast<Intrinsic>(intrinsic_name) (XLenVT GPR:$rs1),
                                                         (XLenVT AVL:$atn))),
          (!cast<Instruction>(inst_name) (XLenVT GPR:$rs1),
                                         GPR:$atn, log2sew, 1)>;
454+
455+
// Selects the operand-less intrinsic named by intrinsic_name to the
// operand-less instruction named by inst_name.
class VPatXSfmmVTDiscard<string intrinsic_name, string inst_name>
    : Pat<(!cast<Intrinsic>(intrinsic_name)),
          (!cast<Instruction>(inst_name))>;
459+
460+
// One tile-store pattern per element width; log2sew is derived from the
// width with !logtwo.
foreach eew = [8, 16, 32, 64] in {
  def : VPatXSfmmTileStore<"int_riscv_sf_vste" # eew,
                           "PseudoSF_VSTE" # eew,
                           !logtwo(eew)>;
}

// Tile <-> vector move patterns for each listed M8 vector type info.
foreach vti = [VI8M8, VI16M8, VI32M8, VI64M8,
               VF16M8, VF32M8, VF64M8, VBF16M8] in {
  def : VPatXSfmmTileMove_T_V<"int_riscv_sf_vtmv_t_v", "PseudoSF_VTMV_T_V",
                              vti.Vector, vti.Log2SEW>;
  def : VPatXSfmmTileMove_V_T<"int_riscv_sf_vtmv_v_t", "PseudoSF_VTMV_V_T",
                              vti.Vector, vti.Log2SEW>;
}

def : VPatXSfmmVTDiscard<"int_riscv_sf_vtdiscard", "PseudoSF_VTDISCARD">;

0 commit comments

Comments
 (0)