Skip to content

Commit a07ace7

Browse files
4vtomatBeMg
andcommitted
[RISCV] Support XSfmm LLVM IR and CodeGen
Co-authored-by: Piyou Chen <[email protected]>
1 parent 7802640 commit a07ace7

30 files changed

+1141
-0
lines changed

llvm/include/llvm/IR/IntrinsicsRISCVXsf.td

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,4 +180,99 @@ let TargetPrefix = "riscv" in {
180180
// XSfvfnrclipxfqf
181181
defm int_riscv_sf_vfnrclip_x_f_qf : RISCVSFCustomVFNRCLIP;
182182
defm int_riscv_sf_vfnrclip_xu_f_qf : RISCVSFCustomVFNRCLIP;
183+
184+
// XSfmm
185+
// Output: (output_len)
186+
// Input: (input_len, vsew, twiden)
187+
class RISCVSFVSet
188+
: DefaultAttrsIntrinsic<[llvm_anyint_ty],
189+
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>],
190+
[ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>, IntrNoMem]>;
191+
192+
// Input: (tss, base, tn)
193+
// IntrReadMem and IntrHasSideEffects do not work for pattern matching, so they are omitted here.
194+
class RISCVSFTileLoad
195+
: DefaultAttrsIntrinsic<[],
196+
[llvm_anyint_ty, llvm_ptr_ty, LLVMMatchType<0>],
197+
[NoCapture<ArgIndex<1>>]>,
198+
RISCVVIntrinsic;
199+
200+
// Input: (tss, base, tn)
201+
class RISCVSFTileStore
202+
: DefaultAttrsIntrinsic<[],
203+
[llvm_anyint_ty, llvm_ptr_ty, LLVMMatchType<0>],
204+
[NoCapture<ArgIndex<1>>, IntrWriteMem,
205+
IntrHasSideEffects]>,
206+
RISCVVIntrinsic;
207+
208+
// Output: ()
209+
// Input: (mtd, mat1, mat2, tm, tn, tk, twiden)
210+
class RISCVSFCustomMatMul<bit is_float = false>
211+
: DefaultAttrsIntrinsic<[], [llvm_anyint_ty, llvm_anyvector_ty,
212+
!if(is_float, LLVMMatchType<1>,
213+
llvm_anyvector_ty),
214+
LLVMMatchType<0>, LLVMMatchType<0>,
215+
LLVMMatchType<0>, LLVMMatchType<0>],
216+
[IntrNoMem, IntrHasSideEffects,
217+
ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<6>>]>,
218+
RISCVVIntrinsic;
219+
220+
def int_riscv_sf_vsettnt : RISCVSFVSet;
221+
def int_riscv_sf_vsettm : RISCVSFVSet;
222+
def int_riscv_sf_vsettk : RISCVSFVSet;
223+
224+
def int_riscv_sf_vlte8 : RISCVSFTileLoad;
225+
def int_riscv_sf_vlte16 : RISCVSFTileLoad;
226+
def int_riscv_sf_vlte32 : RISCVSFTileLoad;
227+
def int_riscv_sf_vlte64 : RISCVSFTileLoad;
228+
def int_riscv_sf_vste8 : RISCVSFTileStore;
229+
def int_riscv_sf_vste16 : RISCVSFTileStore;
230+
def int_riscv_sf_vste32 : RISCVSFTileStore;
231+
def int_riscv_sf_vste64 : RISCVSFTileStore;
232+
233+
// Output: (vd)
234+
// Input: (tss, tn)
235+
def int_riscv_sf_vtmv_v_t
236+
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
237+
[llvm_anyint_ty, LLVMMatchType<1>],
238+
[IntrNoMem, IntrHasSideEffects]>,
239+
RISCVVIntrinsic {
240+
let VLOperand = 2;
241+
}
242+
// Output: ()
243+
// Input: (tss, vs2, tn)
244+
def int_riscv_sf_vtmv_t_v
245+
: DefaultAttrsIntrinsic<[], [LLVMMatchType<1>, llvm_anyvector_ty,
246+
llvm_anyint_ty], [IntrNoMem, IntrHasSideEffects]>,
247+
RISCVVIntrinsic {
248+
let VLOperand = 2;
249+
}
250+
251+
foreach a = ["u", "s"] in {
252+
foreach b = ["u", "s"] in {
253+
def int_riscv_sf_mm_ # a # _ # b : RISCVSFCustomMatMul;
254+
}
255+
}
256+
257+
def int_riscv_sf_mm_f_f : RISCVSFCustomMatMul<true>;
258+
foreach e1 = [5, 4] in
259+
foreach e2 = [5, 4] in
260+
def int_riscv_sf_mm_e # e1 # m # !sub(7, e1) # _e # e2 # m # !sub(7, e2)
261+
: RISCVSFCustomMatMul<true>;
262+
263+
// Output: ()
264+
// Input: (mtd, tm, tn, log2sew, twiden)
265+
def int_riscv_sf_vtzero_t
266+
: DefaultAttrsIntrinsic<[],
267+
[llvm_anyint_ty, LLVMMatchType<0>,LLVMMatchType<0>,
268+
LLVMMatchType<0>, LLVMMatchType<0>],
269+
[ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<3>>,
270+
ImmArg<ArgIndex<4>>, IntrNoMem, IntrHasSideEffects]>,
271+
RISCVVIntrinsic;
272+
273+
// Output: ()
274+
// Input: ()
275+
def int_riscv_sf_vtdiscard
276+
: DefaultAttrsIntrinsic<[], [], [IntrNoMem, IntrHasSideEffects]>,
277+
RISCVVIntrinsic;
183278
} // TargetPrefix = "riscv"

llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,6 +1100,12 @@ static bool lowerRISCVVMachineInstrToMCInst(const MachineInstr *MI,
11001100
--NumOps;
11011101
if (RISCVII::hasRoundModeOp(TSFlags))
11021102
--NumOps;
1103+
if (RISCVII::hasTWidenOp(TSFlags))
1104+
--NumOps;
1105+
if (RISCVII::hasTMOp(TSFlags))
1106+
--NumOps;
1107+
if (RISCVII::hasTKOp(TSFlags))
1108+
--NumOps;
11031109

11041110
bool hasVLOutput = RISCV::isFaultFirstLoad(*MI);
11051111
for (unsigned OpNo = 0; OpNo != NumOps; ++OpNo) {

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,43 @@ void RISCVDAGToDAGISel::selectVSETVLI(SDNode *Node) {
522522
CurDAG->getMachineNode(Opcode, DL, XLenVT, VLOperand, VTypeIOp));
523523
}
524524

525+
/// Select the XSfmm configuration intrinsics (sf.vsettnt, sf.vsettm,
/// sf.vsettk) into their PseudoSF_VSETTNT / PseudoSF_VSETTM /
/// PseudoSF_VSETTK machine nodes.
///
/// Operands of \p Node as read below: operand 0 is the intrinsic ID,
/// operand 1 is forwarded unchanged to the pseudo, operand 2 is an encoded
/// SEW constant and operand 3 is an encoded twiden constant.
///
/// If the subtarget lacks XSfmmbase the function returns without replacing
/// \p Node.
void RISCVDAGToDAGISel::selectXSfmmVSET(SDNode *Node) {
526+
if (!Subtarget->hasVendorXSfmmbase())
527+
return;
528+
529+
assert(Node->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Unexpected opcode");
530+
531+
SDLoc DL(Node);
532+
MVT XLenVT = Subtarget->getXLenVT();
533+
534+
unsigned IntNo = Node->getConstantOperandVal(0);
535+
536+
assert((IntNo == Intrinsic::riscv_sf_vsettnt ||
537+
IntNo == Intrinsic::riscv_sf_vsettm ||
538+
IntNo == Intrinsic::riscv_sf_vsettk) &&
539+
"Unexpected XSfmm vset intrinsic");
540+
541+
// Decode the constant SEW and twiden operands into plain values.
unsigned SEW = RISCVVType::decodeVSEW(Node->getConstantOperandVal(2));
542+
unsigned Widen = RISCVVType::decodeTWiden(Node->getConstantOperandVal(3));
543+
// Pick the pseudo that corresponds to the intrinsic being selected.
unsigned PseudoOpCode =
544+
IntNo == Intrinsic::riscv_sf_vsettnt ? RISCV::PseudoSF_VSETTNT
545+
: IntNo == Intrinsic::riscv_sf_vsettm ? RISCV::PseudoSF_VSETTM
546+
: RISCV::PseudoSF_VSETTK;
547+
548+
// NOTE(review): the meaning of the trailing 0 passed to encodeXSfmmVType is
// not visible here - confirm against its declaration.
unsigned VTypeI = RISCVVType::encodeXSfmmVType(SEW, Widen, 0);
549+
SDValue VTypeIOp = CurDAG->getTargetConstant(VTypeI, DL, XLenVT);
550+
SDValue Log2SEW = CurDAG->getTargetConstant(Log2_32(SEW), DL, XLenVT);
551+
SDValue TWiden = CurDAG->getTargetConstant(Widen, DL, XLenVT);
552+
553+
// sf.vsettnt receives the packed vtype immediate; sf.vsettm and sf.vsettk
// receive log2(SEW) and twiden as two separate immediates instead.
if (IntNo == Intrinsic::riscv_sf_vsettnt)
554+
ReplaceNode(Node, CurDAG->getMachineNode(PseudoOpCode, DL, XLenVT,
555+
Node->getOperand(1), VTypeIOp));
556+
else
557+
ReplaceNode(Node,
558+
CurDAG->getMachineNode(PseudoOpCode, DL, XLenVT,
559+
Node->getOperand(1), Log2SEW, TWiden));
560+
}
561+
525562
bool RISCVDAGToDAGISel::tryShrinkShlLogicImm(SDNode *Node) {
526563
MVT VT = Node->getSimpleValueType(0);
527564
unsigned Opcode = Node->getOpcode();
@@ -777,6 +814,11 @@ bool RISCVDAGToDAGISel::tryIndexedLoad(SDNode *Node) {
777814
return true;
778815
}
779816

817+
/// Return the register for tile index \p TileNum, computed as an offset from
/// RISCV::T0. Asserts that \p TileNum is at most 15.
// NOTE(review): relies on the tile registers being numbered consecutively
// starting at RISCV::T0 - confirm against the register definitions.
static Register getTileReg(uint64_t TileNum) {
818+
assert(TileNum <= 15 && "Invalid tile number");
819+
return RISCV::T0 + TileNum;
820+
}
821+
780822
void RISCVDAGToDAGISel::selectSF_VC_X_SE(SDNode *Node) {
781823
if (!Subtarget->hasVInstructions())
782824
return;
@@ -1955,6 +1997,10 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
19551997
case Intrinsic::riscv_vsetvli:
19561998
case Intrinsic::riscv_vsetvlimax:
19571999
return selectVSETVLI(Node);
2000+
case Intrinsic::riscv_sf_vsettnt:
2001+
case Intrinsic::riscv_sf_vsettm:
2002+
case Intrinsic::riscv_sf_vsettk:
2003+
return selectXSfmmVSET(Node);
19582004
}
19592005
break;
19602006
}
@@ -2352,6 +2398,142 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
23522398
case Intrinsic::riscv_sf_vc_i_se:
23532399
selectSF_VC_X_SE(Node);
23542400
return;
2401+
case Intrinsic::riscv_sf_vlte8:
2402+
case Intrinsic::riscv_sf_vlte16:
2403+
case Intrinsic::riscv_sf_vlte32:
2404+
case Intrinsic::riscv_sf_vlte64: {
2405+
unsigned Log2SEW;
2406+
unsigned PseudoInst;
2407+
switch (IntNo) {
2408+
case Intrinsic::riscv_sf_vlte8:
2409+
PseudoInst = RISCV::PseudoSF_VLTE8;
2410+
Log2SEW = 3;
2411+
break;
2412+
case Intrinsic::riscv_sf_vlte16:
2413+
PseudoInst = RISCV::PseudoSF_VLTE16;
2414+
Log2SEW = 4;
2415+
break;
2416+
case Intrinsic::riscv_sf_vlte32:
2417+
PseudoInst = RISCV::PseudoSF_VLTE32;
2418+
Log2SEW = 5;
2419+
break;
2420+
case Intrinsic::riscv_sf_vlte64:
2421+
PseudoInst = RISCV::PseudoSF_VLTE64;
2422+
Log2SEW = 6;
2423+
break;
2424+
}
2425+
2426+
SDValue SEWOp = CurDAG->getTargetConstant(Log2SEW, DL, XLenVT);
2427+
SDValue TWidenOp = CurDAG->getTargetConstant(1, DL, XLenVT);
2428+
SmallVector<SDValue, 7> Operands = {Node->getOperand(2),
2429+
Node->getOperand(3),
2430+
Node->getOperand(4),
2431+
SEWOp,
2432+
TWidenOp,
2433+
Node->getOperand(0)};
2434+
2435+
MachineSDNode *TileLoad =
2436+
CurDAG->getMachineNode(PseudoInst, DL, Node->getVTList(), Operands);
2437+
if (auto *MemOp = dyn_cast<MemSDNode>(Node))
2438+
CurDAG->setNodeMemRefs(TileLoad, {MemOp->getMemOperand()});
2439+
2440+
ReplaceNode(Node, TileLoad);
2441+
return;
2442+
}
2443+
case Intrinsic::riscv_sf_mm_s_s:
2444+
case Intrinsic::riscv_sf_mm_s_u:
2445+
case Intrinsic::riscv_sf_mm_u_s:
2446+
case Intrinsic::riscv_sf_mm_u_u:
2447+
case Intrinsic::riscv_sf_mm_e5m2_e5m2:
2448+
case Intrinsic::riscv_sf_mm_e5m2_e4m3:
2449+
case Intrinsic::riscv_sf_mm_e4m3_e5m2:
2450+
case Intrinsic::riscv_sf_mm_e4m3_e4m3:
2451+
case Intrinsic::riscv_sf_mm_f_f: {
2452+
bool HasFRM = false;
2453+
unsigned PseudoInst;
2454+
switch (IntNo) {
2455+
case Intrinsic::riscv_sf_mm_s_s:
2456+
PseudoInst = RISCV::PseudoSF_MM_S_S;
2457+
break;
2458+
case Intrinsic::riscv_sf_mm_s_u:
2459+
PseudoInst = RISCV::PseudoSF_MM_S_U;
2460+
break;
2461+
case Intrinsic::riscv_sf_mm_u_s:
2462+
PseudoInst = RISCV::PseudoSF_MM_U_S;
2463+
break;
2464+
case Intrinsic::riscv_sf_mm_u_u:
2465+
PseudoInst = RISCV::PseudoSF_MM_U_U;
2466+
break;
2467+
case Intrinsic::riscv_sf_mm_e5m2_e5m2:
2468+
PseudoInst = RISCV::PseudoSF_MM_E5M2_E5M2;
2469+
HasFRM = true;
2470+
break;
2471+
case Intrinsic::riscv_sf_mm_e5m2_e4m3:
2472+
PseudoInst = RISCV::PseudoSF_MM_E5M2_E4M3;
2473+
HasFRM = true;
2474+
break;
2475+
case Intrinsic::riscv_sf_mm_e4m3_e5m2:
2476+
PseudoInst = RISCV::PseudoSF_MM_E4M3_E5M2;
2477+
HasFRM = true;
2478+
break;
2479+
case Intrinsic::riscv_sf_mm_e4m3_e4m3:
2480+
PseudoInst = RISCV::PseudoSF_MM_E4M3_E4M3;
2481+
HasFRM = true;
2482+
break;
2483+
case Intrinsic::riscv_sf_mm_f_f:
2484+
if (Node->getOperand(3).getValueType().getScalarType() == MVT::bf16)
2485+
PseudoInst = RISCV::PseudoSF_MM_F_F_ALT;
2486+
else
2487+
PseudoInst = RISCV::PseudoSF_MM_F_F;
2488+
HasFRM = true;
2489+
break;
2490+
}
2491+
uint64_t TileNum = Node->getConstantOperandVal(2);
2492+
SDValue Op1 = Node->getOperand(3);
2493+
SDValue Op2 = Node->getOperand(4);
2494+
MVT VT = Op1->getSimpleValueType(0);
2495+
unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits());
2496+
SDValue TmOp = Node->getOperand(5);
2497+
SDValue TnOp = Node->getOperand(6);
2498+
SDValue TkOp = Node->getOperand(7);
2499+
SDValue TWidenOp = Node->getOperand(8);
2500+
SDValue Chain = Node->getOperand(0);
2501+
2502+
// sf.mm.f.f with sew=32, twiden=2 is invalid
2503+
if (IntNo == Intrinsic::riscv_sf_mm_f_f && Log2SEW == 5 &&
2504+
TWidenOp->getAsZExtVal() == 2)
2505+
report_fatal_error("sf.mm.f.f doesn't support (sew=32, twiden=2)");
2506+
2507+
SmallVector<SDValue, 10> Operands(
2508+
{CurDAG->getRegister(getTileReg(TileNum), XLenVT), Op1, Op2});
2509+
if (HasFRM)
2510+
Operands.push_back(
2511+
CurDAG->getTargetConstant(RISCVFPRndMode::DYN, DL, XLenVT));
2512+
Operands.append({TmOp, TnOp, TkOp,
2513+
CurDAG->getTargetConstant(Log2SEW, DL, XLenVT), TWidenOp,
2514+
Chain});
2515+
2516+
auto *NewNode =
2517+
CurDAG->getMachineNode(PseudoInst, DL, Node->getVTList(), Operands);
2518+
2519+
ReplaceNode(Node, NewNode);
2520+
return;
2521+
}
2522+
case Intrinsic::riscv_sf_vtzero_t: {
2523+
uint64_t TileNum = Node->getConstantOperandVal(2);
2524+
SDValue Tm = Node->getOperand(3);
2525+
SDValue Tn = Node->getOperand(4);
2526+
SDValue Log2SEW = Node->getOperand(5);
2527+
SDValue TWiden = Node->getOperand(6);
2528+
SDValue Chain = Node->getOperand(0);
2529+
auto *NewNode = CurDAG->getMachineNode(
2530+
RISCV::PseudoSF_VTZERO_T, DL, Node->getVTList(),
2531+
{CurDAG->getRegister(getTileReg(TileNum), XLenVT), Tm, Tn, Log2SEW,
2532+
TWiden, Chain});
2533+
2534+
ReplaceNode(Node, NewNode);
2535+
return;
2536+
}
23552537
}
23562538
break;
23572539
}

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
164164
void selectVSXSEG(SDNode *Node, unsigned NF, bool IsMasked, bool IsOrdered);
165165

166166
void selectVSETVLI(SDNode *Node);
167+
void selectXSfmmVSET(SDNode *Node);
167168

168169
void selectSF_VC_X_SE(SDNode *Node);
169170

llvm/lib/Target/RISCV/RISCVInstrInfoXSfmm.td

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,3 +410,54 @@ let hasSideEffects = 1, mayLoad = 0, mayStore = 0 in {
410410
: Pseudo<(outs), (ins TR:$rd, AVL:$atm, AVL:$atn, ixlenimm:$sew, ixlenimm:$twiden), []>, RISCVVPseudo;
411411
def PseudoSF_VTDISCARD : Pseudo<(outs), (ins), []>, RISCVVPseudo;
412412
}
413+
414+
// Selection pattern for a tile-store intrinsic. Matches the intrinsic named
// by intrinsic_name with operands (rs2, rs1, tn) and emits the instruction
// named by inst_name with the same three operands followed by log2sew and a
// constant 1.
// NOTE(review): the trailing 1 is presumably the twiden operand of the
// pseudo - confirm against the pseudo's operand list.
class VPatXSfmmTileStore<string intrinsic_name,
415+
string inst_name,
416+
int log2sew> :
417+
Pat<(!cast<Intrinsic>(intrinsic_name)
418+
(XLenVT GPR:$rs2),
419+
(XLenVT GPR:$rs1),
420+
(XLenVT AVL:$tn)),
421+
(!cast<Instruction>(inst_name)
422+
(XLenVT GPR:$rs2),
423+
(XLenVT GPR:$rs1),
424+
GPR:$tn, log2sew, 1)>;
425+
426+
// Selection pattern for the vector-to-tile move intrinsic. Matches the
// intrinsic named by intrinsic_name with operands (rs1, vs2, atn), where vs2
// is a VRM8 value of type reg_type, and emits the instruction named by
// inst_name with those operands followed by log2sew and a constant 1.
// NOTE(review): the trailing 1 is presumably the twiden operand of the
// pseudo - confirm against the pseudo's operand list.
class VPatXSfmmTileMove_T_V<string intrinsic_name,
427+
string inst_name,
428+
ValueType reg_type,
429+
int log2sew> :
430+
Pat<(!cast<Intrinsic>(intrinsic_name)
431+
(XLenVT GPR:$rs1),
432+
(reg_type VRM8:$vs2),
433+
(XLenVT AVL:$atn)),
434+
(!cast<Instruction>(inst_name)
435+
(XLenVT GPR:$rs1),
436+
(reg_type VRM8:$vs2),
437+
GPR:$atn, log2sew, 1)>;
438+
439+
// Selection pattern for the tile-to-vector move intrinsic. Matches the
// intrinsic named by intrinsic_name with operands (rs1, atn) producing a
// value of type result_type, and emits the instruction named by inst_name
// with those operands followed by log2sew and a constant 1.
// NOTE(review): the trailing 1 is presumably the twiden operand of the
// pseudo - confirm against the pseudo's operand list.
class VPatXSfmmTileMove_V_T<string intrinsic_name,
440+
string inst_name,
441+
ValueType result_type,
442+
int log2sew> :
443+
Pat<(result_type (!cast<Intrinsic>(intrinsic_name)
444+
(XLenVT GPR:$rs1),
445+
(XLenVT AVL:$atn))),
446+
(!cast<Instruction>(inst_name)
447+
(XLenVT GPR:$rs1),
448+
GPR:$atn, log2sew, 1)>;
449+
450+
// Selection pattern that maps the operand-less intrinsic named by
// intrinsic_name directly onto the operand-less instruction named by
// inst_name.
class VPatXSfmmVTDiscard<string intrinsic_name,
451+
string inst_name> :
452+
Pat<(!cast<Intrinsic>(intrinsic_name)),
453+
(!cast<Instruction>(inst_name))>;
454+
455+
foreach eew = [8, 16, 32, 64] in
456+
def : VPatXSfmmTileStore<"int_riscv_sf_vste" # eew, "PseudoSF_VSTE" # eew, !logtwo(eew)>;
457+
458+
foreach vti = [VI8M8, VI16M8, VI32M8, VI64M8, VF16M8, VF32M8, VF64M8, VBF16M8] in {
459+
def : VPatXSfmmTileMove_T_V<"int_riscv_sf_vtmv_t_v", "PseudoSF_VTMV_T_V", vti.Vector, vti.Log2SEW>;
460+
def : VPatXSfmmTileMove_V_T<"int_riscv_sf_vtmv_v_t", "PseudoSF_VTMV_V_T", vti.Vector, vti.Log2SEW>;
461+
}
462+
463+
def : VPatXSfmmVTDiscard<"int_riscv_sf_vtdiscard", "PseudoSF_VTDISCARD">;

0 commit comments

Comments
 (0)