Skip to content

Commit 6cec362

Browse files
4vtomatBeMg
andauthored
[RISCV] Support XSfmm LLVM IR and CodeGen (#143069)
stack on: #143068 Co-authored-by: Piyou Chen <[email protected]>
1 parent 812a225 commit 6cec362

30 files changed

+1141
-0
lines changed

llvm/include/llvm/IR/IntrinsicsRISCVXsf.td

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,4 +180,98 @@ let TargetPrefix = "riscv" in {
180180
// XSfvfnrclipxfqf
181181
defm int_riscv_sf_vfnrclip_x_f_qf : RISCVSFCustomVFNRCLIP;
182182
defm int_riscv_sf_vfnrclip_xu_f_qf : RISCVSFCustomVFNRCLIP;
183+
184+
// XSfmm
185+
// Output: (output_len)
186+
// Input: (input_len, vsew, twiden)
// Common signature of the XSfmm vset intrinsics. One overloaded integer type
// is shared by the result and all three inputs. vsew (arg 1) and twiden
// (arg 2) are ImmArg, so callers must pass constants for them. The intrinsic
// is declared IntrNoMem.
187+
class RISCVSFVSet
188+
: DefaultAttrsIntrinsic<[llvm_anyint_ty],
189+
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>],
190+
[ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>, IntrNoMem]>;
191+
192+
// Input: (tss, base, tn)
// Tile load signature: no result. tss and tn share one overloaded integer
// type; base is a pointer marked NoCapture.
// NOTE(review): unlike RISCVSFTileStore (which adds IntrWriteMem), no
// memory-effect attribute is listed here, only IntrHasSideEffects. Confirm
// that leaving the memory behaviour unconstrained is intentional.
193+
class RISCVSFTileLoad
194+
: DefaultAttrsIntrinsic<[],
195+
[llvm_anyint_ty, llvm_ptr_ty, LLVMMatchType<0>],
196+
[NoCapture<ArgIndex<1>>, IntrHasSideEffects]>,
197+
RISCVVIntrinsic;
198+
199+
// Input: (tss, base, tn)
// Tile store signature: no result. tss and tn share one overloaded integer
// type; base is a pointer marked NoCapture. The intrinsic is declared
// IntrWriteMem in addition to IntrHasSideEffects.
200+
class RISCVSFTileStore
201+
: DefaultAttrsIntrinsic<[],
202+
[llvm_anyint_ty, llvm_ptr_ty, LLVMMatchType<0>],
203+
[NoCapture<ArgIndex<1>>, IntrWriteMem,
204+
IntrHasSideEffects]>,
205+
RISCVVIntrinsic;
206+
207+
// Output: ()
208+
// Input: (mtd, mat1, mat2, tm, tn, tk, twiden)
// Matrix-multiply signature. mtd, tm, tn, tk and twiden share one overloaded
// integer type; mat1 is an overloaded vector type. When is_float is true,
// mat2 is constrained to the same type as mat1 (LLVMMatchType<1>); otherwise
// mat2 is a second, independently overloaded vector type.
// mtd (arg 0) and twiden (arg 6) are ImmArg and must be constants.
209+
class RISCVSFCustomMatMul<bit is_float = false>
210+
: DefaultAttrsIntrinsic<[], [llvm_anyint_ty, llvm_anyvector_ty,
211+
!if(is_float, LLVMMatchType<1>,
212+
llvm_anyvector_ty),
213+
LLVMMatchType<0>, LLVMMatchType<0>,
214+
LLVMMatchType<0>, LLVMMatchType<0>],
215+
[IntrNoMem, IntrHasSideEffects,
216+
ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<6>>]>,
217+
RISCVVIntrinsic;
218+
219+
// Tile-configuration intrinsics; all three use the RISCVSFVSet signature.
def int_riscv_sf_vsettnt : RISCVSFVSet;
220+
def int_riscv_sf_vsettm : RISCVSFVSet;
221+
def int_riscv_sf_vsettk : RISCVSFVSet;
222+
223+
// Tile loads and stores, one intrinsic per element width (8/16/32/64).
def int_riscv_sf_vlte8 : RISCVSFTileLoad;
224+
def int_riscv_sf_vlte16 : RISCVSFTileLoad;
225+
def int_riscv_sf_vlte32 : RISCVSFTileLoad;
226+
def int_riscv_sf_vlte64 : RISCVSFTileLoad;
227+
def int_riscv_sf_vste8 : RISCVSFTileStore;
228+
def int_riscv_sf_vste16 : RISCVSFTileStore;
229+
def int_riscv_sf_vste32 : RISCVSFTileStore;
230+
def int_riscv_sf_vste64 : RISCVSFTileStore;
231+
232+
// Output: (vd)
233+
// Input: (tss, tn)
// The result is an overloaded vector type; tss and tn share one overloaded
// integer type.
// NOTE(review): VLOperand below is 2, yet only two inputs are declared
// (tss, tn). If VLOperand counts inputs from 0 (as it appears to for
// int_riscv_sf_vtmv_t_v, where index 2 is tn), tn would be index 1 here.
// Confirm whether this should be 1.
234+
def int_riscv_sf_vtmv_v_t
235+
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
236+
[llvm_anyint_ty, LLVMMatchType<1>],
237+
[IntrNoMem, IntrHasSideEffects]>,
238+
RISCVVIntrinsic {
239+
let VLOperand = 2;
240+
}
241+
// Output: ()
242+
// Input: (tss, vs2, tn)
// vs2 is an overloaded vector type and tn an overloaded integer type; tss is
// tied to another overloaded type through LLVMMatchType<1> (presumably the
// integer type of tn - confirm the overload numbering).
// VLOperand = 2 corresponds to tn when the inputs are counted from 0.
243+
def int_riscv_sf_vtmv_t_v
244+
: DefaultAttrsIntrinsic<[], [LLVMMatchType<1>, llvm_anyvector_ty,
245+
llvm_anyint_ty], [IntrNoMem, IntrHasSideEffects]>,
246+
RISCVVIntrinsic {
247+
let VLOperand = 2;
248+
}
249+
250+
// Generates int_riscv_sf_mm_u_u, _u_s, _s_u and _s_s. These use the default
// is_float = false, so mat1 and mat2 are overloaded independently.
foreach a = ["u", "s"] in {
251+
foreach b = ["u", "s"] in {
252+
def int_riscv_sf_mm_ # a # _ # b : RISCVSFCustomMatMul;
253+
}
254+
}
255+
256+
// The remaining variants pass is_float = true, so mat2 must have mat1's type.
def int_riscv_sf_mm_f_f : RISCVSFCustomMatMul<true>;
257+
// !sub(7, e) gives the digit after "m": e = 5 -> m2, e = 4 -> m3. The loops
// therefore generate int_riscv_sf_mm_e5m2_e5m2, _e5m2_e4m3, _e4m3_e5m2 and
// _e4m3_e4m3.
foreach e1 = [5, 4] in
258+
foreach e2 = [5, 4] in
259+
def int_riscv_sf_mm_e # e1 # m # !sub(7, e1) # _e # e2 # m # !sub(7, e2)
260+
: RISCVSFCustomMatMul<true>;
261+
262+
// Output: ()
263+
// Input: (mtd, tm, tn, log2sew, twiden)
// Five inputs of one overloaded integer type. The names follow how the
// instruction selector reads them: tile number, Tm, Tn, Log2SEW, TWiden.
// mtd (arg 0), log2sew (arg 3) and twiden (arg 4) are ImmArg and must be
// constants.
264+
def int_riscv_sf_vtzero_t
265+
: DefaultAttrsIntrinsic<[],
266+
[llvm_anyint_ty, LLVMMatchType<0>,LLVMMatchType<0>,
267+
LLVMMatchType<0>, LLVMMatchType<0>],
268+
[ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<3>>,
269+
ImmArg<ArgIndex<4>>, IntrNoMem, IntrHasSideEffects]>,
270+
RISCVVIntrinsic;
271+
272+
// Output: ()
273+
// Input: ()
// Takes no operands and produces no result; it is declared IntrNoMem together
// with IntrHasSideEffects.
274+
def int_riscv_sf_vtdiscard
275+
: DefaultAttrsIntrinsic<[], [], [IntrNoMem, IntrHasSideEffects]>,
276+
RISCVVIntrinsic;
183277
} // TargetPrefix = "riscv"

llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,12 @@ static bool lowerRISCVVMachineInstrToMCInst(const MachineInstr *MI,
11001100
--NumOps;
11011101
if (RISCVII::hasRoundModeOp(TSFlags))
11021102
--NumOps;
1103+
if (RISCVII::hasTWidenOp(TSFlags))
1104+
--NumOps;
1105+
if (RISCVII::hasTMOp(TSFlags))
1106+
--NumOps;
1107+
if (RISCVII::hasTKOp(TSFlags))
1108+
--NumOps;
11031109

11041110
bool hasVLOutput = RISCVInstrInfo::isFaultOnlyFirstLoad(*MI);
11051111
for (unsigned OpNo = 0; OpNo != NumOps; ++OpNo) {

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,44 @@ void RISCVDAGToDAGISel::selectVSETVLI(SDNode *Node) {
516516
CurDAG->getMachineNode(Opcode, DL, XLenVT, VLOperand, VTypeIOp));
517517
}
518518

519+
/// Select the XSfmm vset intrinsics (sf.vsettnt, sf.vsettm, sf.vsettk), which
/// arrive as INTRINSIC_WO_CHAIN nodes, into their pseudo instructions.
///
/// Operand 1 is forwarded unchanged. Operands 2 and 3 are constants decoded as
/// SEW and TWiden:
///  - sf.vsettnt receives both packed into one XSfmm vtype immediate;
///  - sf.vsettm / sf.vsettk receive log2(SEW) and TWiden as two immediates.
///
/// Returns without touching \p Node when the subtarget lacks XSfmmbase.
void RISCVDAGToDAGISel::selectXSfmmVSET(SDNode *Node) {
  if (!Subtarget->hasVendorXSfmmbase())
    return;

  assert(Node->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Unexpected opcode");

  const unsigned IntNo = Node->getConstantOperandVal(0);
  const bool IsVSetTNT = IntNo == Intrinsic::riscv_sf_vsettnt;
  const bool IsVSetTM = IntNo == Intrinsic::riscv_sf_vsettm;

  assert((IsVSetTNT || IsVSetTM || IntNo == Intrinsic::riscv_sf_vsettk) &&
         "Unexpected XSfmm vset intrinsic");

  SDLoc DL(Node);
  MVT XLenVT = Subtarget->getXLenVT();

  const unsigned SEW = RISCVVType::decodeVSEW(Node->getConstantOperandVal(2));
  const unsigned Widen =
      RISCVVType::decodeTWiden(Node->getConstantOperandVal(3));

  // sf.vsettnt carries SEW and TWiden in a single encoded vtype immediate.
  if (IsVSetTNT) {
    unsigned VTypeI = RISCVVType::encodeXSfmmVType(SEW, Widen, 0);
    SDValue VTypeIOp = CurDAG->getTargetConstant(VTypeI, DL, XLenVT);
    auto *VSetTNT = CurDAG->getMachineNode(RISCV::PseudoSF_VSETTNT, DL, XLenVT,
                                           Node->getOperand(1), VTypeIOp);
    ReplaceNode(Node, VSetTNT);
    return;
  }

  // sf.vsettm and sf.vsettk take log2(SEW) and TWiden as separate operands.
  unsigned PseudoOpCode =
      IsVSetTM ? RISCV::PseudoSF_VSETTM : RISCV::PseudoSF_VSETTK;
  SDValue Log2SEW = CurDAG->getTargetConstant(Log2_32(SEW), DL, XLenVT);
  SDValue TWiden = CurDAG->getTargetConstant(Widen, DL, XLenVT);
  auto *VSet = CurDAG->getMachineNode(PseudoOpCode, DL, XLenVT,
                                      Node->getOperand(1), Log2SEW, TWiden);
  ReplaceNode(Node, VSet);
}
556+
519557
bool RISCVDAGToDAGISel::tryShrinkShlLogicImm(SDNode *Node) {
520558
MVT VT = Node->getSimpleValueType(0);
521559
unsigned Opcode = Node->getOpcode();
@@ -847,6 +885,11 @@ bool RISCVDAGToDAGISel::tryIndexedLoad(SDNode *Node) {
847885
return true;
848886
}
849887

888+
/// Return the tile register for tile index \p TileNum, computed as an offset
/// from RISCV::T0. Indices above 15 are rejected by assertion.
static Register getTileReg(uint64_t TileNum) {
  assert(TileNum < 16 && "Invalid tile number");
  return RISCV::T0 + TileNum;
}
892+
850893
void RISCVDAGToDAGISel::selectSF_VC_X_SE(SDNode *Node) {
851894
if (!Subtarget->hasVInstructions())
852895
return;
@@ -2035,6 +2078,10 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
20352078
case Intrinsic::riscv_vsetvli:
20362079
case Intrinsic::riscv_vsetvlimax:
20372080
return selectVSETVLI(Node);
2081+
case Intrinsic::riscv_sf_vsettnt:
2082+
case Intrinsic::riscv_sf_vsettm:
2083+
case Intrinsic::riscv_sf_vsettk:
2084+
return selectXSfmmVSET(Node);
20382085
}
20392086
break;
20402087
}
@@ -2458,6 +2505,142 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
24582505
case Intrinsic::riscv_sf_vc_i_se:
24592506
selectSF_VC_X_SE(Node);
24602507
return;
2508+
case Intrinsic::riscv_sf_vlte8:
2509+
case Intrinsic::riscv_sf_vlte16:
2510+
case Intrinsic::riscv_sf_vlte32:
2511+
case Intrinsic::riscv_sf_vlte64: {
2512+
unsigned Log2SEW;
2513+
unsigned PseudoInst;
2514+
switch (IntNo) {
2515+
case Intrinsic::riscv_sf_vlte8:
2516+
PseudoInst = RISCV::PseudoSF_VLTE8;
2517+
Log2SEW = 3;
2518+
break;
2519+
case Intrinsic::riscv_sf_vlte16:
2520+
PseudoInst = RISCV::PseudoSF_VLTE16;
2521+
Log2SEW = 4;
2522+
break;
2523+
case Intrinsic::riscv_sf_vlte32:
2524+
PseudoInst = RISCV::PseudoSF_VLTE32;
2525+
Log2SEW = 5;
2526+
break;
2527+
case Intrinsic::riscv_sf_vlte64:
2528+
PseudoInst = RISCV::PseudoSF_VLTE64;
2529+
Log2SEW = 6;
2530+
break;
2531+
}
2532+
2533+
SDValue SEWOp = CurDAG->getTargetConstant(Log2SEW, DL, XLenVT);
2534+
SDValue TWidenOp = CurDAG->getTargetConstant(1, DL, XLenVT);
2535+
SDValue Operands[] = {Node->getOperand(2),
2536+
Node->getOperand(3),
2537+
Node->getOperand(4),
2538+
SEWOp,
2539+
TWidenOp,
2540+
Node->getOperand(0)};
2541+
2542+
MachineSDNode *TileLoad =
2543+
CurDAG->getMachineNode(PseudoInst, DL, Node->getVTList(), Operands);
2544+
if (auto *MemOp = dyn_cast<MemSDNode>(Node))
2545+
CurDAG->setNodeMemRefs(TileLoad, {MemOp->getMemOperand()});
2546+
2547+
ReplaceNode(Node, TileLoad);
2548+
return;
2549+
}
2550+
case Intrinsic::riscv_sf_mm_s_s:
2551+
case Intrinsic::riscv_sf_mm_s_u:
2552+
case Intrinsic::riscv_sf_mm_u_s:
2553+
case Intrinsic::riscv_sf_mm_u_u:
2554+
case Intrinsic::riscv_sf_mm_e5m2_e5m2:
2555+
case Intrinsic::riscv_sf_mm_e5m2_e4m3:
2556+
case Intrinsic::riscv_sf_mm_e4m3_e5m2:
2557+
case Intrinsic::riscv_sf_mm_e4m3_e4m3:
2558+
case Intrinsic::riscv_sf_mm_f_f: {
2559+
bool HasFRM = false;
2560+
unsigned PseudoInst;
2561+
switch (IntNo) {
2562+
case Intrinsic::riscv_sf_mm_s_s:
2563+
PseudoInst = RISCV::PseudoSF_MM_S_S;
2564+
break;
2565+
case Intrinsic::riscv_sf_mm_s_u:
2566+
PseudoInst = RISCV::PseudoSF_MM_S_U;
2567+
break;
2568+
case Intrinsic::riscv_sf_mm_u_s:
2569+
PseudoInst = RISCV::PseudoSF_MM_U_S;
2570+
break;
2571+
case Intrinsic::riscv_sf_mm_u_u:
2572+
PseudoInst = RISCV::PseudoSF_MM_U_U;
2573+
break;
2574+
case Intrinsic::riscv_sf_mm_e5m2_e5m2:
2575+
PseudoInst = RISCV::PseudoSF_MM_E5M2_E5M2;
2576+
HasFRM = true;
2577+
break;
2578+
case Intrinsic::riscv_sf_mm_e5m2_e4m3:
2579+
PseudoInst = RISCV::PseudoSF_MM_E5M2_E4M3;
2580+
HasFRM = true;
2581+
break;
2582+
case Intrinsic::riscv_sf_mm_e4m3_e5m2:
2583+
PseudoInst = RISCV::PseudoSF_MM_E4M3_E5M2;
2584+
HasFRM = true;
2585+
break;
2586+
case Intrinsic::riscv_sf_mm_e4m3_e4m3:
2587+
PseudoInst = RISCV::PseudoSF_MM_E4M3_E4M3;
2588+
HasFRM = true;
2589+
break;
2590+
case Intrinsic::riscv_sf_mm_f_f:
2591+
if (Node->getOperand(3).getValueType().getScalarType() == MVT::bf16)
2592+
PseudoInst = RISCV::PseudoSF_MM_F_F_ALT;
2593+
else
2594+
PseudoInst = RISCV::PseudoSF_MM_F_F;
2595+
HasFRM = true;
2596+
break;
2597+
}
2598+
uint64_t TileNum = Node->getConstantOperandVal(2);
2599+
SDValue Op1 = Node->getOperand(3);
2600+
SDValue Op2 = Node->getOperand(4);
2601+
MVT VT = Op1->getSimpleValueType(0);
2602+
unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits());
2603+
SDValue TmOp = Node->getOperand(5);
2604+
SDValue TnOp = Node->getOperand(6);
2605+
SDValue TkOp = Node->getOperand(7);
2606+
SDValue TWidenOp = Node->getOperand(8);
2607+
SDValue Chain = Node->getOperand(0);
2608+
2609+
// sf.mm.f.f with sew=32, twiden=2 is invalid
2610+
if (IntNo == Intrinsic::riscv_sf_mm_f_f && Log2SEW == 5 &&
2611+
TWidenOp->getAsZExtVal() == 2)
2612+
reportFatalUsageError("sf.mm.f.f doesn't support (sew=32, twiden=2)");
2613+
2614+
SmallVector<SDValue, 10> Operands(
2615+
{CurDAG->getRegister(getTileReg(TileNum), XLenVT), Op1, Op2});
2616+
if (HasFRM)
2617+
Operands.push_back(
2618+
CurDAG->getTargetConstant(RISCVFPRndMode::DYN, DL, XLenVT));
2619+
Operands.append({TmOp, TnOp, TkOp,
2620+
CurDAG->getTargetConstant(Log2SEW, DL, XLenVT), TWidenOp,
2621+
Chain});
2622+
2623+
auto *NewNode =
2624+
CurDAG->getMachineNode(PseudoInst, DL, Node->getVTList(), Operands);
2625+
2626+
ReplaceNode(Node, NewNode);
2627+
return;
2628+
}
2629+
case Intrinsic::riscv_sf_vtzero_t: {
2630+
uint64_t TileNum = Node->getConstantOperandVal(2);
2631+
SDValue Tm = Node->getOperand(3);
2632+
SDValue Tn = Node->getOperand(4);
2633+
SDValue Log2SEW = Node->getOperand(5);
2634+
SDValue TWiden = Node->getOperand(6);
2635+
SDValue Chain = Node->getOperand(0);
2636+
auto *NewNode = CurDAG->getMachineNode(
2637+
RISCV::PseudoSF_VTZERO_T, DL, Node->getVTList(),
2638+
{CurDAG->getRegister(getTileReg(TileNum), XLenVT), Tm, Tn, Log2SEW,
2639+
TWiden, Chain});
2640+
2641+
ReplaceNode(Node, NewNode);
2642+
return;
2643+
}
24612644
}
24622645
break;
24632646
}

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
165165
void selectVSXSEG(SDNode *Node, unsigned NF, bool IsMasked, bool IsOrdered);
166166

167167
void selectVSETVLI(SDNode *Node);
168+
void selectXSfmmVSET(SDNode *Node);
168169

169170
void selectSF_VC_X_SE(SDNode *Node);
170171

llvm/lib/Target/RISCV/RISCVInstrInfoXSfmm.td

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,54 @@ let hasSideEffects = 1, mayLoad = 0, mayStore = 0 in {
418418
ixlenimm:$twiden)>;
419419
def PseudoSF_VTDISCARD : RISCVVPseudo<(outs), (ins), []>;
420420
}
421+
422+
// Selects the tile-store intrinsic named by intrinsic_name, called with
// (rs2, rs1, tn), to the instruction named by inst_name. The three operands
// are forwarded and two immediates are appended: log2sew and the constant 1
// (presumably the pseudo's twiden operand - confirm against its definition).
class VPatXSfmmTileStore<string intrinsic_name,
423+
string inst_name,
424+
int log2sew> :
425+
Pat<(!cast<Intrinsic>(intrinsic_name)
426+
(XLenVT GPR:$rs2),
427+
(XLenVT GPR:$rs1),
428+
(XLenVT AVL:$tn)),
429+
(!cast<Instruction>(inst_name)
430+
(XLenVT GPR:$rs2),
431+
(XLenVT GPR:$rs1),
432+
GPR:$tn, log2sew, 1)>;
433+
434+
// Selects a vector-to-tile move intrinsic called with (rs1, vs2, atn), where
// vs2 is a reg_type value in VRM8, to the instruction named by inst_name.
// The operands are forwarded and two immediates are appended: log2sew and the
// constant 1 (presumably the pseudo's twiden operand - confirm).
class VPatXSfmmTileMove_T_V<string intrinsic_name,
435+
string inst_name,
436+
ValueType reg_type,
437+
int log2sew> :
438+
Pat<(!cast<Intrinsic>(intrinsic_name)
439+
(XLenVT GPR:$rs1),
440+
(reg_type VRM8:$vs2),
441+
(XLenVT AVL:$atn)),
442+
(!cast<Instruction>(inst_name)
443+
(XLenVT GPR:$rs1),
444+
(reg_type VRM8:$vs2),
445+
GPR:$atn, log2sew, 1)>;
446+
447+
// Selects a tile-to-vector move intrinsic called with (rs1, atn) and
// producing a result_type value, to the instruction named by inst_name.
// The operands are forwarded and two immediates are appended: log2sew and the
// constant 1 (presumably the pseudo's twiden operand - confirm).
class VPatXSfmmTileMove_V_T<string intrinsic_name,
448+
string inst_name,
449+
ValueType result_type,
450+
int log2sew> :
451+
Pat<(result_type (!cast<Intrinsic>(intrinsic_name)
452+
(XLenVT GPR:$rs1),
453+
(XLenVT AVL:$atn))),
454+
(!cast<Instruction>(inst_name)
455+
(XLenVT GPR:$rs1),
456+
GPR:$atn, log2sew, 1)>;
457+
458+
// Maps the operand-less intrinsic named by intrinsic_name directly onto the
// operand-less instruction named by inst_name.
class VPatXSfmmVTDiscard<string intrinsic_name,
459+
string inst_name> :
460+
Pat<(!cast<Intrinsic>(intrinsic_name)),
461+
(!cast<Instruction>(inst_name))>;
462+
463+
// One tile-store pattern per element width; !logtwo(eew) supplies the log2sew
// immediate (8 -> 3, 16 -> 4, 32 -> 5, 64 -> 6).
foreach eew = [8, 16, 32, 64] in
464+
def : VPatXSfmmTileStore<"int_riscv_sf_vste" # eew, "PseudoSF_VSTE" # eew, !logtwo(eew)>;
465+
466+
// Tile <-> vector move patterns for each listed M8 vector type info, using
// the type's Vector value type and its Log2SEW.
foreach vti = [VI8M8, VI16M8, VI32M8, VI64M8, VF16M8, VF32M8, VF64M8, VBF16M8] in {
467+
def : VPatXSfmmTileMove_T_V<"int_riscv_sf_vtmv_t_v", "PseudoSF_VTMV_T_V", vti.Vector, vti.Log2SEW>;
468+
def : VPatXSfmmTileMove_V_T<"int_riscv_sf_vtmv_v_t", "PseudoSF_VTMV_V_T", vti.Vector, vti.Log2SEW>;
469+
}
470+
471+
def : VPatXSfmmVTDiscard<"int_riscv_sf_vtdiscard", "PseudoSF_VTDISCARD">;

0 commit comments

Comments
 (0)