Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsRISCVXsf.td
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,98 @@ let TargetPrefix = "riscv" in {
// XSfvfnrclipxfqf
defm int_riscv_sf_vfnrclip_x_f_qf : RISCVSFCustomVFNRCLIP;
defm int_riscv_sf_vfnrclip_xu_f_qf : RISCVSFCustomVFNRCLIP;

// XSfmm
// Output: (output_len)
// Input: (input_len, vsew, twiden)
class RISCVSFVSet
: DefaultAttrsIntrinsic<[llvm_anyint_ty],
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>],
[ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>, IntrNoMem]>;

// Input: (tss, base, tn)
class RISCVSFTileLoad
: DefaultAttrsIntrinsic<[],
[llvm_anyint_ty, llvm_ptr_ty, LLVMMatchType<0>],
[NoCapture<ArgIndex<1>>, IntrHasSideEffects]>,
RISCVVIntrinsic;

// Input: (tss, base, tn)
class RISCVSFTileStore
: DefaultAttrsIntrinsic<[],
[llvm_anyint_ty, llvm_ptr_ty, LLVMMatchType<0>],
[NoCapture<ArgIndex<1>>, IntrWriteMem,
IntrHasSideEffects]>,
RISCVVIntrinsic;

// Output: ()
// Input: (mtd, mat1, mat2, tm, tn, tk, twiden)
class RISCVSFCustomMatMul<bit is_float = false>
: DefaultAttrsIntrinsic<[], [llvm_anyint_ty, llvm_anyvector_ty,
!if(is_float, LLVMMatchType<1>,
llvm_anyvector_ty),
LLVMMatchType<0>, LLVMMatchType<0>,
LLVMMatchType<0>, LLVMMatchType<0>],
[IntrNoMem, IntrHasSideEffects,
ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<6>>]>,
RISCVVIntrinsic;

def int_riscv_sf_vsettnt : RISCVSFVSet;
def int_riscv_sf_vsettm : RISCVSFVSet;
def int_riscv_sf_vsettk : RISCVSFVSet;

def int_riscv_sf_vlte8 : RISCVSFTileLoad;
def int_riscv_sf_vlte16 : RISCVSFTileLoad;
def int_riscv_sf_vlte32 : RISCVSFTileLoad;
def int_riscv_sf_vlte64 : RISCVSFTileLoad;
def int_riscv_sf_vste8 : RISCVSFTileStore;
def int_riscv_sf_vste16 : RISCVSFTileStore;
def int_riscv_sf_vste32 : RISCVSFTileStore;
def int_riscv_sf_vste64 : RISCVSFTileStore;

// Output: (vd)
// Input: (tss, tn)
def int_riscv_sf_vtmv_v_t
: DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[llvm_anyint_ty, LLVMMatchType<1>],
[IntrNoMem, IntrHasSideEffects]>,
RISCVVIntrinsic {
let VLOperand = 2;
}
// Output: ()
// Input: (tss, vs2, tn)
def int_riscv_sf_vtmv_t_v
: DefaultAttrsIntrinsic<[], [LLVMMatchType<1>, llvm_anyvector_ty,
llvm_anyint_ty], [IntrNoMem, IntrHasSideEffects]>,
RISCVVIntrinsic {
let VLOperand = 2;
}

foreach a = ["u", "s"] in {
foreach b = ["u", "s"] in {
def int_riscv_sf_mm_ # a # _ # b : RISCVSFCustomMatMul;
}
}

def int_riscv_sf_mm_f_f : RISCVSFCustomMatMul<true>;
foreach e1 = [5, 4] in
foreach e2 = [5, 4] in
def int_riscv_sf_mm_e # e1 # m # !sub(7, e1) # _e # e2 # m # !sub(7, e2)
: RISCVSFCustomMatMul<true>;

// Output: ()
// Input: (mtd)
def int_riscv_sf_vtzero_t
: DefaultAttrsIntrinsic<[],
[llvm_anyint_ty, LLVMMatchType<0>,LLVMMatchType<0>,
LLVMMatchType<0>, LLVMMatchType<0>],
[ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<3>>,
ImmArg<ArgIndex<4>>, IntrNoMem, IntrHasSideEffects]>,
RISCVVIntrinsic;

// Output: ()
// Input: ()
def int_riscv_sf_vtdiscard
: DefaultAttrsIntrinsic<[], [], [IntrNoMem, IntrHasSideEffects]>,
RISCVVIntrinsic;
} // TargetPrefix = "riscv"
6 changes: 6 additions & 0 deletions llvm/lib/Target/RISCV/RISCVAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,12 @@ static bool lowerRISCVVMachineInstrToMCInst(const MachineInstr *MI,
--NumOps;
if (RISCVII::hasRoundModeOp(TSFlags))
--NumOps;
if (RISCVII::hasTWidenOp(TSFlags))
--NumOps;
if (RISCVII::hasTMOp(TSFlags))
--NumOps;
if (RISCVII::hasTKOp(TSFlags))
--NumOps;

bool hasVLOutput = RISCVInstrInfo::isFaultOnlyFirstLoad(*MI);
for (unsigned OpNo = 0; OpNo != NumOps; ++OpNo) {
Expand Down
183 changes: 183 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,44 @@ void RISCVDAGToDAGISel::selectVSETVLI(SDNode *Node) {
CurDAG->getMachineNode(Opcode, DL, XLenVT, VLOperand, VTypeIOp));
}

void RISCVDAGToDAGISel::selectXSfmmVSET(SDNode *Node) {
if (!Subtarget->hasVendorXSfmmbase())
return;

assert(Node->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Unexpected opcode");

SDLoc DL(Node);
MVT XLenVT = Subtarget->getXLenVT();

unsigned IntNo = Node->getConstantOperandVal(0);

assert((IntNo == Intrinsic::riscv_sf_vsettnt ||
IntNo == Intrinsic::riscv_sf_vsettm ||
IntNo == Intrinsic::riscv_sf_vsettk) &&
"Unexpected XSfmm vset intrinsic");

unsigned SEW = RISCVVType::decodeVSEW(Node->getConstantOperandVal(2));
unsigned Widen = RISCVVType::decodeTWiden(Node->getConstantOperandVal(3));
unsigned PseudoOpCode =
IntNo == Intrinsic::riscv_sf_vsettnt ? RISCV::PseudoSF_VSETTNT
: IntNo == Intrinsic::riscv_sf_vsettm ? RISCV::PseudoSF_VSETTM
: RISCV::PseudoSF_VSETTK;

if (IntNo == Intrinsic::riscv_sf_vsettnt) {
unsigned VTypeI = RISCVVType::encodeXSfmmVType(SEW, Widen, 0);
SDValue VTypeIOp = CurDAG->getTargetConstant(VTypeI, DL, XLenVT);

ReplaceNode(Node, CurDAG->getMachineNode(PseudoOpCode, DL, XLenVT,
Node->getOperand(1), VTypeIOp));
} else {
SDValue Log2SEW = CurDAG->getTargetConstant(Log2_32(SEW), DL, XLenVT);
SDValue TWiden = CurDAG->getTargetConstant(Widen, DL, XLenVT);
ReplaceNode(Node,
CurDAG->getMachineNode(PseudoOpCode, DL, XLenVT,
Node->getOperand(1), Log2SEW, TWiden));
}
}

bool RISCVDAGToDAGISel::tryShrinkShlLogicImm(SDNode *Node) {
MVT VT = Node->getSimpleValueType(0);
unsigned Opcode = Node->getOpcode();
Expand Down Expand Up @@ -847,6 +885,11 @@ bool RISCVDAGToDAGISel::tryIndexedLoad(SDNode *Node) {
return true;
}

static Register getTileReg(uint64_t TileNum) {
assert(TileNum <= 15 && "Invalid tile number");
return RISCV::T0 + TileNum;
}

void RISCVDAGToDAGISel::selectSF_VC_X_SE(SDNode *Node) {
if (!Subtarget->hasVInstructions())
return;
Expand Down Expand Up @@ -2035,6 +2078,10 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
case Intrinsic::riscv_vsetvli:
case Intrinsic::riscv_vsetvlimax:
return selectVSETVLI(Node);
case Intrinsic::riscv_sf_vsettnt:
case Intrinsic::riscv_sf_vsettm:
case Intrinsic::riscv_sf_vsettk:
return selectXSfmmVSET(Node);
}
break;
}
Expand Down Expand Up @@ -2458,6 +2505,142 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
case Intrinsic::riscv_sf_vc_i_se:
selectSF_VC_X_SE(Node);
return;
case Intrinsic::riscv_sf_vlte8:
case Intrinsic::riscv_sf_vlte16:
case Intrinsic::riscv_sf_vlte32:
case Intrinsic::riscv_sf_vlte64: {
unsigned Log2SEW;
unsigned PseudoInst;
switch (IntNo) {
case Intrinsic::riscv_sf_vlte8:
PseudoInst = RISCV::PseudoSF_VLTE8;
Log2SEW = 3;
break;
case Intrinsic::riscv_sf_vlte16:
PseudoInst = RISCV::PseudoSF_VLTE16;
Log2SEW = 4;
break;
case Intrinsic::riscv_sf_vlte32:
PseudoInst = RISCV::PseudoSF_VLTE32;
Log2SEW = 5;
break;
case Intrinsic::riscv_sf_vlte64:
PseudoInst = RISCV::PseudoSF_VLTE64;
Log2SEW = 6;
break;
}

SDValue SEWOp = CurDAG->getTargetConstant(Log2SEW, DL, XLenVT);
SDValue TWidenOp = CurDAG->getTargetConstant(1, DL, XLenVT);
SDValue Operands[] = {Node->getOperand(2),
Node->getOperand(3),
Node->getOperand(4),
SEWOp,
TWidenOp,
Node->getOperand(0)};

MachineSDNode *TileLoad =
CurDAG->getMachineNode(PseudoInst, DL, Node->getVTList(), Operands);
if (auto *MemOp = dyn_cast<MemSDNode>(Node))
CurDAG->setNodeMemRefs(TileLoad, {MemOp->getMemOperand()});

ReplaceNode(Node, TileLoad);
return;
}
case Intrinsic::riscv_sf_mm_s_s:
case Intrinsic::riscv_sf_mm_s_u:
case Intrinsic::riscv_sf_mm_u_s:
case Intrinsic::riscv_sf_mm_u_u:
case Intrinsic::riscv_sf_mm_e5m2_e5m2:
case Intrinsic::riscv_sf_mm_e5m2_e4m3:
case Intrinsic::riscv_sf_mm_e4m3_e5m2:
case Intrinsic::riscv_sf_mm_e4m3_e4m3:
case Intrinsic::riscv_sf_mm_f_f: {
bool HasFRM = false;
unsigned PseudoInst;
switch (IntNo) {
case Intrinsic::riscv_sf_mm_s_s:
PseudoInst = RISCV::PseudoSF_MM_S_S;
break;
case Intrinsic::riscv_sf_mm_s_u:
PseudoInst = RISCV::PseudoSF_MM_S_U;
break;
case Intrinsic::riscv_sf_mm_u_s:
PseudoInst = RISCV::PseudoSF_MM_U_S;
break;
case Intrinsic::riscv_sf_mm_u_u:
PseudoInst = RISCV::PseudoSF_MM_U_U;
break;
case Intrinsic::riscv_sf_mm_e5m2_e5m2:
PseudoInst = RISCV::PseudoSF_MM_E5M2_E5M2;
HasFRM = true;
break;
case Intrinsic::riscv_sf_mm_e5m2_e4m3:
PseudoInst = RISCV::PseudoSF_MM_E5M2_E4M3;
HasFRM = true;
break;
case Intrinsic::riscv_sf_mm_e4m3_e5m2:
PseudoInst = RISCV::PseudoSF_MM_E4M3_E5M2;
HasFRM = true;
break;
case Intrinsic::riscv_sf_mm_e4m3_e4m3:
PseudoInst = RISCV::PseudoSF_MM_E4M3_E4M3;
HasFRM = true;
break;
case Intrinsic::riscv_sf_mm_f_f:
if (Node->getOperand(3).getValueType().getScalarType() == MVT::bf16)
PseudoInst = RISCV::PseudoSF_MM_F_F_ALT;
else
PseudoInst = RISCV::PseudoSF_MM_F_F;
HasFRM = true;
break;
}
uint64_t TileNum = Node->getConstantOperandVal(2);
SDValue Op1 = Node->getOperand(3);
SDValue Op2 = Node->getOperand(4);
MVT VT = Op1->getSimpleValueType(0);
unsigned Log2SEW = Log2_32(VT.getScalarSizeInBits());
SDValue TmOp = Node->getOperand(5);
SDValue TnOp = Node->getOperand(6);
SDValue TkOp = Node->getOperand(7);
SDValue TWidenOp = Node->getOperand(8);
SDValue Chain = Node->getOperand(0);

// sf.mm.f.f with sew=32, twiden=2 is invalid
if (IntNo == Intrinsic::riscv_sf_mm_f_f && Log2SEW == 5 &&
TWidenOp->getAsZExtVal() == 2)
reportFatalUsageError("sf.mm.f.f doesn't support (sew=32, twiden=2)");

SmallVector<SDValue, 10> Operands(
{CurDAG->getRegister(getTileReg(TileNum), XLenVT), Op1, Op2});
if (HasFRM)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you plan to add FRM to the intrinsic operands?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think I'll list as TODO

Operands.push_back(
CurDAG->getTargetConstant(RISCVFPRndMode::DYN, DL, XLenVT));
Operands.append({TmOp, TnOp, TkOp,
CurDAG->getTargetConstant(Log2SEW, DL, XLenVT), TWidenOp,
Chain});

auto *NewNode =
CurDAG->getMachineNode(PseudoInst, DL, Node->getVTList(), Operands);

ReplaceNode(Node, NewNode);
return;
}
case Intrinsic::riscv_sf_vtzero_t: {
uint64_t TileNum = Node->getConstantOperandVal(2);
SDValue Tm = Node->getOperand(3);
SDValue Tn = Node->getOperand(4);
SDValue Log2SEW = Node->getOperand(5);
SDValue TWiden = Node->getOperand(6);
SDValue Chain = Node->getOperand(0);
auto *NewNode = CurDAG->getMachineNode(
RISCV::PseudoSF_VTZERO_T, DL, Node->getVTList(),
{CurDAG->getRegister(getTileReg(TileNum), XLenVT), Tm, Tn, Log2SEW,
TWiden, Chain});

ReplaceNode(Node, NewNode);
return;
}
}
break;
}
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
void selectVSXSEG(SDNode *Node, unsigned NF, bool IsMasked, bool IsOrdered);

void selectVSETVLI(SDNode *Node);
void selectXSfmmVSET(SDNode *Node);

void selectSF_VC_X_SE(SDNode *Node);

Expand Down
51 changes: 51 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoXSfmm.td
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,54 @@ let hasSideEffects = 1, mayLoad = 0, mayStore = 0 in {
ixlenimm:$twiden)>;
def PseudoSF_VTDISCARD : RISCVVPseudo<(outs), (ins), []>;
}

class VPatXSfmmTileStore<string intrinsic_name,
string inst_name,
int log2sew> :
Pat<(!cast<Intrinsic>(intrinsic_name)
(XLenVT GPR:$rs2),
(XLenVT GPR:$rs1),
(XLenVT AVL:$tn)),
(!cast<Instruction>(inst_name)
(XLenVT GPR:$rs2),
(XLenVT GPR:$rs1),
GPR:$tn, log2sew, 1)>;

class VPatXSfmmTileMove_T_V<string intrinsic_name,
string inst_name,
ValueType reg_type,
int log2sew> :
Pat<(!cast<Intrinsic>(intrinsic_name)
(XLenVT GPR:$rs1),
(reg_type VRM8:$vs2),
(XLenVT AVL:$atn)),
(!cast<Instruction>(inst_name)
(XLenVT GPR:$rs1),
(reg_type VRM8:$vs2),
GPR:$atn, log2sew, 1)>;

class VPatXSfmmTileMove_V_T<string intrinsic_name,
string inst_name,
ValueType result_type,
int log2sew> :
Pat<(result_type (!cast<Intrinsic>(intrinsic_name)
(XLenVT GPR:$rs1),
(XLenVT AVL:$atn))),
(!cast<Instruction>(inst_name)
(XLenVT GPR:$rs1),
GPR:$atn, log2sew, 1)>;

class VPatXSfmmVTDiscard<string intrinsic_name,
string inst_name> :
Pat<(!cast<Intrinsic>(intrinsic_name)),
(!cast<Instruction>(inst_name))>;

foreach eew = [8, 16, 32, 64] in
def : VPatXSfmmTileStore<"int_riscv_sf_vste" # eew, "PseudoSF_VSTE" # eew, !logtwo(eew)>;

foreach vti = [VI8M8, VI16M8, VI32M8, VI64M8, VF16M8, VF32M8, VF64M8, VBF16M8] in {
def : VPatXSfmmTileMove_T_V<"int_riscv_sf_vtmv_t_v", "PseudoSF_VTMV_T_V", vti.Vector, vti.Log2SEW>;
def : VPatXSfmmTileMove_V_T<"int_riscv_sf_vtmv_v_t", "PseudoSF_VTMV_V_T", vti.Vector, vti.Log2SEW>;
}

def : VPatXSfmmVTDiscard<"int_riscv_sf_vtdiscard", "PseudoSF_VTDISCARD">;
Loading