Skip to content

Commit b67796f

Browse files
authored
[llvm][RISCV] Support Zvfbfa codegen for fneg, fabs and copysign (llvm#166944)
This is the first patch for Zvfbfa codegen; I'm going to break it down into several patches to make it easier to review. The codegen supports both scalable vectors and fixed-length vectors, for both native operations and VP intrinsics.
1 parent f734ceb commit b67796f

14 files changed

+3344
-179
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ static cl::opt<bool>
8787
"be combined with a shift"),
8888
cl::init(true));
8989

90+
// TODO: Support more ops
91+
static const unsigned ZvfbfaVPOps[] = {ISD::VP_FNEG, ISD::VP_FABS,
92+
ISD::VP_FCOPYSIGN};
93+
static const unsigned ZvfbfaOps[] = {ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN};
94+
9095
RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
9196
const RISCVSubtarget &STI)
9297
: TargetLowering(TM), Subtarget(STI) {
@@ -1208,6 +1213,61 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
12081213
}
12091214
};
12101215

1216+
// Sets common actions for zvfbfa; some of the instructions are supported
1217+
// natively so that we don't need to promote them.
1218+
const auto SetZvfbfaActions = [&](MVT VT) {
1219+
setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom);
1220+
setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT,
1221+
Custom);
1222+
setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
1223+
setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
1224+
setOperationAction({ISD::LROUND, ISD::LLROUND}, VT, Custom);
1225+
setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT,
1226+
Custom);
1227+
setOperationAction(ISD::SELECT_CC, VT, Expand);
1228+
setOperationAction({ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP}, VT, Custom);
1229+
setOperationAction({ISD::INSERT_VECTOR_ELT, ISD::CONCAT_VECTORS,
1230+
ISD::INSERT_SUBVECTOR, ISD::EXTRACT_SUBVECTOR,
1231+
ISD::VECTOR_DEINTERLEAVE, ISD::VECTOR_INTERLEAVE,
1232+
ISD::VECTOR_REVERSE, ISD::VECTOR_SPLICE,
1233+
ISD::VECTOR_COMPRESS},
1234+
VT, Custom);
1235+
setOperationAction(ISD::EXPERIMENTAL_VP_SPLICE, VT, Custom);
1236+
setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom);
1237+
1238+
setOperationAction(ISD::FCOPYSIGN, VT, Legal);
1239+
setOperationAction(ZvfbfaVPOps, VT, Custom);
1240+
1241+
MVT EltVT = VT.getVectorElementType();
1242+
if (isTypeLegal(EltVT))
1243+
setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT,
1244+
ISD::EXTRACT_VECTOR_ELT},
1245+
VT, Custom);
1246+
else
1247+
setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT},
1248+
EltVT, Custom);
1249+
setOperationAction({ISD::LOAD, ISD::STORE, ISD::MLOAD, ISD::MSTORE,
1250+
ISD::MGATHER, ISD::MSCATTER, ISD::VP_LOAD,
1251+
ISD::VP_STORE, ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
1252+
ISD::EXPERIMENTAL_VP_STRIDED_STORE, ISD::VP_GATHER,
1253+
ISD::VP_SCATTER},
1254+
VT, Custom);
1255+
setOperationAction(ISD::VP_LOAD_FF, VT, Custom);
1256+
1257+
// Expand FP operations that need libcalls.
1258+
setOperationAction(FloatingPointLibCallOps, VT, Expand);
1259+
1260+
// Custom split nxv32[b]f16 since nxv32[b]f32 is not legal.
1261+
if (getLMUL(VT) == RISCVVType::LMUL_8) {
1262+
setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom);
1263+
setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom);
1264+
} else {
1265+
MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
1266+
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
1267+
setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT);
1268+
}
1269+
};
1270+
12111271
if (Subtarget.hasVInstructionsF16()) {
12121272
for (MVT VT : F16VecVTs) {
12131273
if (!isTypeLegal(VT))
@@ -1222,7 +1282,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
12221282
}
12231283
}
12241284

1225-
if (Subtarget.hasVInstructionsBF16Minimal()) {
1285+
if (Subtarget.hasVInstructionsBF16()) {
1286+
for (MVT VT : BF16VecVTs) {
1287+
if (!isTypeLegal(VT))
1288+
continue;
1289+
SetZvfbfaActions(VT);
1290+
}
1291+
} else if (Subtarget.hasVInstructionsBF16Minimal()) {
12261292
for (MVT VT : BF16VecVTs) {
12271293
if (!isTypeLegal(VT))
12281294
continue;
@@ -1501,6 +1567,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
15011567
// available.
15021568
setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
15031569
}
1570+
if (Subtarget.hasStdExtZvfbfa()) {
1571+
setOperationAction(ZvfbfaOps, VT, Custom);
1572+
setOperationAction(ZvfbfaVPOps, VT, Custom);
1573+
}
15041574
setOperationAction(
15051575
{ISD::VP_MERGE, ISD::VP_SELECT, ISD::VSELECT, ISD::SELECT}, VT,
15061576
Custom);
@@ -7245,7 +7315,11 @@ static bool isPromotedOpNeedingSplit(SDValue Op,
72457315
return (Op.getValueType() == MVT::nxv32f16 &&
72467316
(Subtarget.hasVInstructionsF16Minimal() &&
72477317
!Subtarget.hasVInstructionsF16())) ||
7248-
Op.getValueType() == MVT::nxv32bf16;
7318+
(Op.getValueType() == MVT::nxv32bf16 &&
7319+
Subtarget.hasVInstructionsBF16Minimal() &&
7320+
(!Subtarget.hasVInstructionsBF16() ||
7321+
(!llvm::is_contained(ZvfbfaOps, Op.getOpcode()) &&
7322+
!llvm::is_contained(ZvfbfaVPOps, Op.getOpcode()))));
72497323
}
72507324

72517325
static SDValue SplitVectorOp(SDValue Op, SelectionDAG &DAG) {

llvm/lib/Target/RISCV/RISCVInstrInfoZvfbf.td

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,5 +701,86 @@ let Predicates = [HasStdExtZvfbfa] in {
701701
FRM_DYN,
702702
fvti.AVL, fvti.Log2SEW, TA_MA)>;
703703
}
704-
}
704+
705+
foreach vti = AllBF16Vectors in {
706+
// 13.12. Vector Floating-Point Sign-Injection Instructions
707+
def : Pat<(fabs (vti.Vector vti.RegClass:$rs)),
708+
(!cast<Instruction>("PseudoVFSGNJX_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW)
709+
(vti.Vector (IMPLICIT_DEF)),
710+
vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>;
711+
// Handle fneg with VFSGNJN using the same input for both operands.
712+
def : Pat<(fneg (vti.Vector vti.RegClass:$rs)),
713+
(!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW)
714+
(vti.Vector (IMPLICIT_DEF)),
715+
vti.RegClass:$rs, vti.RegClass:$rs, vti.AVL, vti.Log2SEW, TA_MA)>;
716+
717+
def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
718+
(vti.Vector vti.RegClass:$rs2))),
719+
(!cast<Instruction>("PseudoVFSGNJ_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW)
720+
(vti.Vector (IMPLICIT_DEF)),
721+
vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
722+
def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
723+
(vti.Vector (SplatFPOp vti.ScalarRegClass:$rs2)))),
724+
(!cast<Instruction>("PseudoVFSGNJ_ALT_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
725+
(vti.Vector (IMPLICIT_DEF)),
726+
vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
727+
728+
def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
729+
(vti.Vector (fneg vti.RegClass:$rs2)))),
730+
(!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW)
731+
(vti.Vector (IMPLICIT_DEF)),
732+
vti.RegClass:$rs1, vti.RegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
733+
def : Pat<(vti.Vector (fcopysign (vti.Vector vti.RegClass:$rs1),
734+
(vti.Vector (fneg (SplatFPOp vti.ScalarRegClass:$rs2))))),
735+
(!cast<Instruction>("PseudoVFSGNJN_ALT_V"#vti.ScalarSuffix#"_"#vti.LMul.MX#"_E"#vti.SEW)
736+
(vti.Vector (IMPLICIT_DEF)),
737+
vti.RegClass:$rs1, vti.ScalarRegClass:$rs2, vti.AVL, vti.Log2SEW, TA_MA)>;
738+
739+
// 13.12. Vector Floating-Point Sign-Injection Instructions
740+
def : Pat<(riscv_fabs_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm),
741+
VLOpFrag),
742+
(!cast<Instruction>("PseudoVFSGNJX_ALT_VV_"# vti.LMul.MX #"_E"#vti.SEW#"_MASK")
743+
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs,
744+
vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
745+
TA_MA)>;
746+
// Handle fneg with VFSGNJN using the same input for both operands.
747+
def : Pat<(riscv_fneg_vl (vti.Vector vti.RegClass:$rs), (vti.Mask VMV0:$vm),
748+
VLOpFrag),
749+
(!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW #"_MASK")
750+
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs,
751+
vti.RegClass:$rs, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
752+
TA_MA)>;
753+
754+
def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1),
755+
(vti.Vector vti.RegClass:$rs2),
756+
vti.RegClass:$passthru,
757+
(vti.Mask VMV0:$vm),
758+
VLOpFrag),
759+
(!cast<Instruction>("PseudoVFSGNJ_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK")
760+
vti.RegClass:$passthru, vti.RegClass:$rs1,
761+
vti.RegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
762+
TAIL_AGNOSTIC)>;
763+
764+
def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1),
765+
(riscv_fneg_vl vti.RegClass:$rs2,
766+
(vti.Mask true_mask),
767+
VLOpFrag),
768+
srcvalue,
769+
(vti.Mask true_mask),
770+
VLOpFrag),
771+
(!cast<Instruction>("PseudoVFSGNJN_ALT_VV_"# vti.LMul.MX#"_E"#vti.SEW)
772+
(vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1,
773+
vti.RegClass:$rs2, GPR:$vl, vti.Log2SEW, TA_MA)>;
774+
775+
def : Pat<(riscv_fcopysign_vl (vti.Vector vti.RegClass:$rs1),
776+
(SplatFPOp vti.ScalarRegClass:$rs2),
777+
vti.RegClass:$passthru,
778+
(vti.Mask VMV0:$vm),
779+
VLOpFrag),
780+
(!cast<Instruction>("PseudoVFSGNJ_ALT_V"#vti.ScalarSuffix#"_"# vti.LMul.MX#"_E"#vti.SEW#"_MASK")
781+
vti.RegClass:$passthru, vti.RegClass:$rs1,
782+
vti.ScalarRegClass:$rs2, (vti.Mask VMV0:$vm), GPR:$vl, vti.Log2SEW,
783+
TAIL_AGNOSTIC)>;
784+
}
785+
}
705786
} // Predicates = [HasStdExtZvfbfa]
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
2+
; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvfbfa \
3+
; RUN: -target-abi=ilp32d -verify-machineinstrs < %s | FileCheck %s
4+
; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvfbfa \
5+
; RUN: -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
6+
7+
; fcopysign on <2 x bfloat> should select Zvfbfa's vfsgnj.vv under the
; alternate-format SEW encoding (e16alt) at LMUL=mf4.
define <2 x bfloat> @copysign_v2bf16(<2 x bfloat> %vm, <2 x bfloat> %vs) {
8+
; CHECK-LABEL: copysign_v2bf16:
9+
; CHECK: # %bb.0:
10+
; CHECK-NEXT: vsetivli zero, 2, e16alt, mf4, ta, ma
11+
; CHECK-NEXT: vfsgnj.vv v8, v8, v9
12+
; CHECK-NEXT: ret
13+
%r = call <2 x bfloat> @llvm.copysign.v2bf16(<2 x bfloat> %vm, <2 x bfloat> %vs)
14+
ret <2 x bfloat> %r
15+
}
16+
17+
; fcopysign on <4 x bfloat> should select Zvfbfa's vfsgnj.vv under the
; alternate-format SEW encoding (e16alt) at LMUL=mf2.
define <4 x bfloat> @copysign_v4bf16(<4 x bfloat> %vm, <4 x bfloat> %vs) {
18+
; CHECK-LABEL: copysign_v4bf16:
19+
; CHECK: # %bb.0:
20+
; CHECK-NEXT: vsetivli zero, 4, e16alt, mf2, ta, ma
21+
; CHECK-NEXT: vfsgnj.vv v8, v8, v9
22+
; CHECK-NEXT: ret
23+
%r = call <4 x bfloat> @llvm.copysign.v4bf16(<4 x bfloat> %vm, <4 x bfloat> %vs)
24+
ret <4 x bfloat> %r
25+
}
26+
27+
; fcopysign on <8 x bfloat> should select Zvfbfa's vfsgnj.vv under the
; alternate-format SEW encoding (e16alt) at LMUL=m1.
define <8 x bfloat> @copysign_v8bf16(<8 x bfloat> %vm, <8 x bfloat> %vs) {
28+
; CHECK-LABEL: copysign_v8bf16:
29+
; CHECK: # %bb.0:
30+
; CHECK-NEXT: vsetivli zero, 8, e16alt, m1, ta, ma
31+
; CHECK-NEXT: vfsgnj.vv v8, v8, v9
32+
; CHECK-NEXT: ret
33+
%r = call <8 x bfloat> @llvm.copysign.v8bf16(<8 x bfloat> %vm, <8 x bfloat> %vs)
34+
ret <8 x bfloat> %r
35+
}
36+
37+
; fcopysign on <16 x bfloat> should select Zvfbfa's vfsgnj.vv under the
; alternate-format SEW encoding (e16alt) at LMUL=m2.
define <16 x bfloat> @copysign_v16bf16(<16 x bfloat> %vm, <16 x bfloat> %vs) {
38+
; CHECK-LABEL: copysign_v16bf16:
39+
; CHECK: # %bb.0:
40+
; CHECK-NEXT: vsetivli zero, 16, e16alt, m2, ta, ma
41+
; CHECK-NEXT: vfsgnj.vv v8, v8, v10
42+
; CHECK-NEXT: ret
43+
%r = call <16 x bfloat> @llvm.copysign.v16bf16(<16 x bfloat> %vm, <16 x bfloat> %vs)
44+
ret <16 x bfloat> %r
45+
}
46+
47+
; fcopysign on <32 x bfloat> should select Zvfbfa's vfsgnj.vv under the
; alternate-format SEW encoding (e16alt) at LMUL=m4; AVL=32 does not fit the
; vsetivli immediate, so it is materialized in a0 and vsetvli is used.
; Fix: the original was named copysign_v32bf32, but the vectors are
; <32 x bfloat>, whose overload mangling suffix is v32bf16 — renamed the
; test function, its CHECK-LABEL, and the intrinsic call for consistency
; with the v2/v4/v8/v16 tests above.
define <32 x bfloat> @copysign_v32bf16(<32 x bfloat> %vm, <32 x bfloat> %vs) {
; CHECK-LABEL: copysign_v32bf16:
; CHECK: # %bb.0:
; CHECK-NEXT: li a0, 32
; CHECK-NEXT: vsetvli zero, a0, e16alt, m4, ta, ma
; CHECK-NEXT: vfsgnj.vv v8, v8, v12
; CHECK-NEXT: ret
  %r = call <32 x bfloat> @llvm.copysign.v32bf16(<32 x bfloat> %vm, <32 x bfloat> %vs)
  ret <32 x bfloat> %r
}

0 commit comments

Comments
 (0)