Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 74 additions & 68 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
Original file line number Diff line number Diff line change
Expand Up @@ -345,49 +345,60 @@ defset list<VTypeInfo> AllVectors = {
}
}

defset list<VTypeInfo> AllFloatVectors = {
defset list<VTypeInfo> NoGroupFloatVectors = {
defset list<VTypeInfo> FractionalGroupFloatVectors = {
def VF16MF4: VTypeInfo<vfloat16mf4_t, vbool64_t, 16, V_MF4, f16, FPR16>;
def VF16MF2: VTypeInfo<vfloat16mf2_t, vbool32_t, 16, V_MF2, f16, FPR16>;
def VF32MF2: VTypeInfo<vfloat32mf2_t, vbool64_t, 32, V_MF2, f32, FPR32>;
def VBF16MF4: VTypeInfo<vbfloat16mf4_t, vbool64_t, 16, V_MF4, bf16, FPR16>;
def VBF16MF2: VTypeInfo<vbfloat16mf2_t, vbool32_t, 16, V_MF2, bf16, FPR16>;
defset list<VTypeInfo> AllFloatAndBFloatVectors = {
defset list<VTypeInfo> AllFloatVectors = {
defset list<VTypeInfo> NoGroupFloatVectors = {
defset list<VTypeInfo> FractionalGroupFloatVectors = {
def VF16MF4: VTypeInfo<vfloat16mf4_t, vbool64_t, 16, V_MF4, f16, FPR16>;
def VF16MF2: VTypeInfo<vfloat16mf2_t, vbool32_t, 16, V_MF2, f16, FPR16>;
def VF32MF2: VTypeInfo<vfloat32mf2_t, vbool64_t, 32, V_MF2, f32, FPR32>;
}
def VF16M1: VTypeInfo<vfloat16m1_t, vbool16_t, 16, V_M1, f16, FPR16>;
def VF32M1: VTypeInfo<vfloat32m1_t, vbool32_t, 32, V_M1, f32, FPR32>;
def VF64M1: VTypeInfo<vfloat64m1_t, vbool64_t, 64, V_M1, f64, FPR64>;
}

defset list<GroupVTypeInfo> GroupFloatVectors = {
def VF16M2: GroupVTypeInfo<vfloat16m2_t, vfloat16m1_t, vbool8_t, 16,
V_M2, f16, FPR16>;
def VF16M4: GroupVTypeInfo<vfloat16m4_t, vfloat16m1_t, vbool4_t, 16,
V_M4, f16, FPR16>;
def VF16M8: GroupVTypeInfo<vfloat16m8_t, vfloat16m1_t, vbool2_t, 16,
V_M8, f16, FPR16>;

def VF32M2: GroupVTypeInfo<vfloat32m2_t, vfloat32m1_t, vbool16_t, 32,
V_M2, f32, FPR32>;
def VF32M4: GroupVTypeInfo<vfloat32m4_t, vfloat32m1_t, vbool8_t, 32,
V_M4, f32, FPR32>;
def VF32M8: GroupVTypeInfo<vfloat32m8_t, vfloat32m1_t, vbool4_t, 32,
V_M8, f32, FPR32>;

def VF64M2: GroupVTypeInfo<vfloat64m2_t, vfloat64m1_t, vbool32_t, 64,
V_M2, f64, FPR64>;
def VF64M4: GroupVTypeInfo<vfloat64m4_t, vfloat64m1_t, vbool16_t, 64,
V_M4, f64, FPR64>;
def VF64M8: GroupVTypeInfo<vfloat64m8_t, vfloat64m1_t, vbool8_t, 64,
V_M8, f64, FPR64>;
}
def VF16M1: VTypeInfo<vfloat16m1_t, vbool16_t, 16, V_M1, f16, FPR16>;
def VF32M1: VTypeInfo<vfloat32m1_t, vbool32_t, 32, V_M1, f32, FPR32>;
def VF64M1: VTypeInfo<vfloat64m1_t, vbool64_t, 64, V_M1, f64, FPR64>;
def VBF16M1: VTypeInfo<vbfloat16m1_t, vbool16_t, 16, V_M1, bf16, FPR16>;
}

defset list<GroupVTypeInfo> GroupFloatVectors = {
def VF16M2: GroupVTypeInfo<vfloat16m2_t, vfloat16m1_t, vbool8_t, 16,
V_M2, f16, FPR16>;
def VF16M4: GroupVTypeInfo<vfloat16m4_t, vfloat16m1_t, vbool4_t, 16,
V_M4, f16, FPR16>;
def VF16M8: GroupVTypeInfo<vfloat16m8_t, vfloat16m1_t, vbool2_t, 16,
V_M8, f16, FPR16>;

def VF32M2: GroupVTypeInfo<vfloat32m2_t, vfloat32m1_t, vbool16_t, 32,
V_M2, f32, FPR32>;
def VF32M4: GroupVTypeInfo<vfloat32m4_t, vfloat32m1_t, vbool8_t, 32,
V_M4, f32, FPR32>;
def VF32M8: GroupVTypeInfo<vfloat32m8_t, vfloat32m1_t, vbool4_t, 32,
V_M8, f32, FPR32>;

def VF64M2: GroupVTypeInfo<vfloat64m2_t, vfloat64m1_t, vbool32_t, 64,
V_M2, f64, FPR64>;
def VF64M4: GroupVTypeInfo<vfloat64m4_t, vfloat64m1_t, vbool16_t, 64,
V_M4, f64, FPR64>;
def VF64M8: GroupVTypeInfo<vfloat64m8_t, vfloat64m1_t, vbool8_t, 64,
V_M8, f64, FPR64>;

def VBF16M2: GroupVTypeInfo<vbfloat16m2_t, vbfloat16m1_t, vbool8_t, 16,
V_M2, bf16, FPR16>;
def VBF16M4: GroupVTypeInfo<vbfloat16m4_t, vbfloat16m1_t, vbool4_t, 16,
V_M4, bf16, FPR16>;
def VBF16M8: GroupVTypeInfo<vbfloat16m8_t, vbfloat16m1_t, vbool2_t, 16,
V_M8, bf16, FPR16>;
defset list<VTypeInfo> AllBFloatVectors = {
defset list<VTypeInfo> NoGroupBFloatVectors = {
defset list<VTypeInfo> FractionalGroupBFloatVectors = {
def VBF16MF4: VTypeInfo<vbfloat16mf4_t, vbool64_t, 16, V_MF4, bf16, FPR16>;
def VBF16MF2: VTypeInfo<vbfloat16mf2_t, vbool32_t, 16, V_MF2, bf16, FPR16>;
}
def VBF16M1: VTypeInfo<vbfloat16m1_t, vbool16_t, 16, V_M1, bf16, FPR16>;
}

defset list<GroupVTypeInfo> GroupBFloatVectors = {
def VBF16M2: GroupVTypeInfo<vbfloat16m2_t, vbfloat16m1_t, vbool8_t, 16,
V_M2, bf16, FPR16>;
def VBF16M4: GroupVTypeInfo<vbfloat16m4_t, vbfloat16m1_t, vbool4_t, 16,
V_M4, bf16, FPR16>;
def VBF16M8: GroupVTypeInfo<vbfloat16m8_t, vbfloat16m1_t, vbool2_t, 16,
V_M8, bf16, FPR16>;
}
}
}
}
Expand Down Expand Up @@ -7143,31 +7154,32 @@ defm : VPatConversionVI_VF<"int_riscv_vfclass", "PseudoVFCLASS">;
// We can use vmerge.vvm to support vector-vector vfmerge.
// NOTE: Clang previously used int_riscv_vfmerge for vector-vector, but now uses
// int_riscv_vmerge. Support both for compatibility.
foreach vti = AllFloatVectors in {
foreach vti = AllFloatAndBFloatVectors in {
let Predicates = GetVTypeMinimalPredicates<vti>.Predicates in
defm : VPatBinaryCarryInTAIL<"int_riscv_vmerge", "PseudoVMERGE", "VVM",
vti.Vector,
vti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass,
vti.RegClass, vti.RegClass>;
let Predicates = GetVTypePredicates<vti>.Predicates in
defm : VPatBinaryCarryInTAIL<"int_riscv_vfmerge", "PseudoVFMERGE",
"V"#vti.ScalarSuffix#"M",
vti.Vector,
vti.Vector, vti.Scalar, vti.Mask,
vti.Log2SEW, vti.LMul, vti.RegClass,
vti.RegClass, vti.ScalarRegClass>;
}

foreach fvti = AllFloatVectors in {
defvar instr = !cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX);
let Predicates = GetVTypePredicates<fvti>.Predicates in
def : Pat<(fvti.Vector (int_riscv_vfmerge (fvti.Vector fvti.RegClass:$passthru),
(fvti.Vector fvti.RegClass:$rs2),
(fvti.Scalar (fpimm0)),
(fvti.Mask VMV0:$vm), VLOpFrag)),
(instr fvti.RegClass:$passthru, fvti.RegClass:$rs2, 0,
(fvti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW)>;
let Predicates = GetVTypePredicates<fvti>.Predicates in {
defm : VPatBinaryCarryInTAIL<"int_riscv_vfmerge", "PseudoVFMERGE",
"V"#fvti.ScalarSuffix#"M",
fvti.Vector,
fvti.Vector, fvti.Scalar, fvti.Mask,
fvti.Log2SEW, fvti.LMul, fvti.RegClass,
fvti.RegClass, fvti.ScalarRegClass>;

defvar instr = !cast<Instruction>("PseudoVMERGE_VIM_"#fvti.LMul.MX);
def : Pat<(fvti.Vector (int_riscv_vfmerge (fvti.Vector fvti.RegClass:$passthru),
(fvti.Vector fvti.RegClass:$rs2),
(fvti.Scalar (fpimm0)),
(fvti.Mask VMV0:$vm), VLOpFrag)),
(instr fvti.RegClass:$passthru, fvti.RegClass:$rs2, 0,
(fvti.Mask VMV0:$vm), GPR:$vl, fvti.Log2SEW)>;
}
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -7328,33 +7340,27 @@ foreach vti = NoGroupIntegerVectors in {
//===----------------------------------------------------------------------===//
// 16.3. Vector Slide Instructions
//===----------------------------------------------------------------------===//
defm : VPatTernaryV_VX_VI<"int_riscv_vslideup", "PseudoVSLIDEUP", AllIntegerVectors, uimm5>;
defm : VPatTernaryV_VX_VI<"int_riscv_vslidedown", "PseudoVSLIDEDOWN", AllIntegerVectors, uimm5>;
defm : VPatTernaryV_VX_VI<"int_riscv_vslideup", "PseudoVSLIDEUP", AllVectors, uimm5>;
defm : VPatTernaryV_VX_VI<"int_riscv_vslidedown", "PseudoVSLIDEDOWN", AllVectors, uimm5>;

defm : VPatBinaryV_VX<"int_riscv_vslide1up", "PseudoVSLIDE1UP", AllIntegerVectors>;
defm : VPatBinaryV_VX<"int_riscv_vslide1down", "PseudoVSLIDE1DOWN", AllIntegerVectors>;

defm : VPatTernaryV_VX_VI<"int_riscv_vslideup", "PseudoVSLIDEUP", AllFloatVectors, uimm5>;
defm : VPatTernaryV_VX_VI<"int_riscv_vslidedown", "PseudoVSLIDEDOWN", AllFloatVectors, uimm5>;
defm : VPatBinaryV_VX<"int_riscv_vfslide1up", "PseudoVFSLIDE1UP", AllFloatVectors>;
defm : VPatBinaryV_VX<"int_riscv_vfslide1down", "PseudoVFSLIDE1DOWN", AllFloatVectors>;

//===----------------------------------------------------------------------===//
// 16.4. Vector Register Gather Instructions
//===----------------------------------------------------------------------===//
defm : VPatBinaryV_VV_VX_VI_INT<"int_riscv_vrgather", "PseudoVRGATHER",
AllIntegerVectors, uimm5>;
AllVectors, uimm5>;
defm : VPatBinaryV_VV_INT_EEW<"int_riscv_vrgatherei16_vv", "PseudoVRGATHEREI16",
eew=16, vtilist=AllIntegerVectors>;
eew=16, vtilist=AllVectors>;

defm : VPatBinaryV_VV_VX_VI_INT<"int_riscv_vrgather", "PseudoVRGATHER",
AllFloatVectors, uimm5>;
defm : VPatBinaryV_VV_INT_EEW<"int_riscv_vrgatherei16_vv", "PseudoVRGATHEREI16",
eew=16, vtilist=AllFloatVectors>;
//===----------------------------------------------------------------------===//
// 16.5. Vector Compress Instruction
//===----------------------------------------------------------------------===//
defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllIntegerVectors>;
defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllFloatVectors>;
defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllVectors>;

// Include the non-intrinsic ISel patterns
include "RISCVInstrInfoVVLPatterns.td"
Expand Down
9 changes: 6 additions & 3 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ defm : VPatFPSetCCSDNode_VV_VF_FV<SETOLE, "PseudoVMFLE", "PseudoVMFGE">;
// Floating-point vselects:
// 11.15. Vector Integer Merge Instructions
// 13.15. Vector Floating-Point Merge Instruction
foreach fvti = AllFloatVectors in {
foreach fvti = AllFloatAndBFloatVectors in {
defvar ivti = GetIntVTypeInfo<fvti>.Vti;
let Predicates = GetVTypePredicates<ivti>.Predicates in {
def : Pat<(fvti.Vector (vselect (fvti.Mask VMV0:$vm), fvti.RegClass:$rs1,
Expand All @@ -1397,7 +1397,12 @@ foreach fvti = AllFloatVectors in {
(fvti.Vector (IMPLICIT_DEF)),
fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask VMV0:$vm),
fvti.AVL, fvti.Log2SEW)>;
}
}

foreach fvti = AllFloatVectors in {
defvar ivti = GetIntVTypeInfo<fvti>.Vti;
let Predicates = GetVTypePredicates<ivti>.Predicates in {
def : Pat<(fvti.Vector (vselect (fvti.Mask VMV0:$vm),
(SplatFPOp (SelectScalarFPAsInt (XLenVT GPR:$imm))),
fvti.RegClass:$rs2)),
Expand All @@ -1412,9 +1417,7 @@ foreach fvti = AllFloatVectors in {
(fvti.Vector (IMPLICIT_DEF)),
fvti.RegClass:$rs2, 0, (fvti.Mask VMV0:$vm), fvti.AVL, fvti.Log2SEW)>;
}
}

foreach fvti = AllFloatVectors in {
let Predicates = GetVTypePredicates<fvti>.Predicates in
def : Pat<(fvti.Vector (vselect (fvti.Mask VMV0:$vm),
(SplatFPOp fvti.ScalarRegClass:$rs1),
Expand Down
17 changes: 10 additions & 7 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -2423,10 +2423,10 @@ foreach vti = AllFloatVectors in {
}
}

foreach fvti = AllFloatVectors in {
// Floating-point vselects:
// 11.15. Vector Integer Merge Instructions
// 13.15. Vector Floating-Point Merge Instruction
// Floating-point vselects:
// 11.15. Vector Integer Merge Instructions
// 13.15. Vector Floating-Point Merge Instruction
foreach fvti = AllFloatAndBFloatVectors in {
defvar ivti = GetIntVTypeInfo<fvti>.Vti;
let Predicates = GetVTypePredicates<ivti>.Predicates in {
def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask VMV0:$vm),
Expand All @@ -2437,7 +2437,12 @@ foreach fvti = AllFloatVectors in {
(!cast<Instruction>("PseudoVMERGE_VVM_"#fvti.LMul.MX)
fvti.RegClass:$passthru, fvti.RegClass:$rs2, fvti.RegClass:$rs1, (fvti.Mask VMV0:$vm),
GPR:$vl, fvti.Log2SEW)>;
}
}

foreach fvti = AllFloatVectors in {
defvar ivti = GetIntVTypeInfo<fvti>.Vti;
let Predicates = GetVTypePredicates<ivti>.Predicates in {
def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask VMV0:$vm),
(SplatFPOp (SelectScalarFPAsInt (XLenVT GPR:$imm))),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From looking at the tests in fixed-vector-vpmerge-bf16.ll I guess we don't currently match vfmerge.vfm for bf16. Is there any reason why we can't if we have zfbfmin?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not technically correct to move a bf16 with vfmerge.vfm. If the scalar register isn't properly nan-boxed the hardware will not create a bf16 nan, it will create an fp16 nan. Though, I don't know if llvm could ever create an improperly nan-boxed value.

There's this text about this in the Zvfbfa spec https://github.com/aswaterman/riscv-misc/blob/main/isa/zvfbfa.adoc

The instructions marked with †† differ only in that improperly NaN-boxed f-register operands must substitute the BF16 canonical NaN instead of the FP16 canonical NaN.

fvti.RegClass:$rs2,
Expand All @@ -2457,9 +2462,7 @@ foreach fvti = AllFloatVectors in {
fvti.RegClass:$passthru, fvti.RegClass:$rs2, 0, (fvti.Mask VMV0:$vm),
GPR:$vl, fvti.Log2SEW)>;
}
}

foreach fvti = AllFloatVectors in {
let Predicates = GetVTypePredicates<fvti>.Predicates in {
def : Pat<(fvti.Vector (riscv_vmerge_vl (fvti.Mask VMV0:$vm),
(SplatFPOp fvti.ScalarRegClass:$rs1),
Expand Down Expand Up @@ -2767,7 +2770,7 @@ foreach vti = NoGroupFloatVectors in {
}
}

foreach vti = AllFloatVectors in {
foreach vti = AllFloatAndBFloatVectors in {
defvar ivti = GetIntVTypeInfo<vti>.Vti;
let Predicates = GetVTypePredicates<ivti>.Predicates in {
def : Pat<(vti.Vector
Expand Down