Skip to content

Commit 9fe0a70

Browse files
authored
[llvm][RISCV] Support splat and vp_splat for zvfbfa codegen (llvm#167920)
1 parent 6eab083 commit 9fe0a70

File tree

7 files changed

+449
-151
lines changed

7 files changed

+449
-151
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,10 @@ static cl::opt<bool>
8888
cl::init(true));
8989

9090
// TODO: Support more ops
91-
static const unsigned ZvfbfaVPOps[] = {ISD::VP_FNEG, ISD::VP_FABS,
92-
ISD::VP_FCOPYSIGN};
93-
static const unsigned ZvfbfaOps[] = {ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN};
91+
static const unsigned ZvfbfaVPOps[] = {
92+
ISD::VP_FNEG, ISD::VP_FABS, ISD::VP_FCOPYSIGN, ISD::EXPERIMENTAL_VP_SPLAT};
93+
static const unsigned ZvfbfaOps[] = {ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN,
94+
ISD::SPLAT_VECTOR};
9495

9596
RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
9697
const RISCVSubtarget &STI)
@@ -1272,17 +1273,12 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
12721273
VT, Custom);
12731274
setOperationAction(ISD::EXPERIMENTAL_VP_SPLICE, VT, Custom);
12741275
setOperationAction(ISD::EXPERIMENTAL_VP_REVERSE, VT, Custom);
1276+
setOperationAction(ISD::EXPERIMENTAL_VP_SPLAT, VT, Custom);
12751277

12761278
setOperationAction(ISD::FCOPYSIGN, VT, Legal);
1279+
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
12771280
setOperationAction(ZvfbfaVPOps, VT, Custom);
12781281

1279-
MVT EltVT = VT.getVectorElementType();
1280-
if (isTypeLegal(EltVT))
1281-
setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT}, VT,
1282-
Custom);
1283-
else
1284-
setOperationAction({ISD::SPLAT_VECTOR, ISD::EXPERIMENTAL_VP_SPLAT},
1285-
EltVT, Custom);
12861282
setOperationAction({ISD::LOAD, ISD::STORE, ISD::MLOAD, ISD::MSTORE,
12871283
ISD::MGATHER, ISD::MSCATTER, ISD::VP_LOAD,
12881284
ISD::VP_STORE, ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
@@ -4870,7 +4866,7 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
48704866

48714867
if (VT.isFloatingPoint()) {
48724868
if ((EltVT == MVT::f16 && !Subtarget.hasStdExtZvfh()) ||
4873-
EltVT == MVT::bf16) {
4869+
(EltVT == MVT::bf16 && !Subtarget.hasVInstructionsBF16())) {
48744870
if ((EltVT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) ||
48754871
(EltVT == MVT::f16 && Subtarget.hasStdExtZfhmin()))
48764872
Scalar = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Scalar);

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp-splat-bf16.ll

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
; RUN: llc -mtriple=riscv32 -target-abi=ilp32d -mattr=+v,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZVFBFMIN
44
; RUN: llc -mtriple=riscv64 -target-abi=lp64d -mattr=+v,+zfbfmin,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZFBFMIN-ZVFBFMIN
55
; RUN: llc -mtriple=riscv64 -target-abi=lp64d -mattr=+v,+zvfbfmin -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZVFBFMIN
6+
; RUN: llc -mtriple=riscv32 -target-abi=ilp32d -mattr=+v,+experimental-zvfbfa -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZVFBFA
7+
; RUN: llc -mtriple=riscv64 -target-abi=lp64d -mattr=+v,+experimental-zvfbfa -verify-machineinstrs < %s | FileCheck %s --check-prefixes=ZVFBFA
68

79
define <8 x bfloat> @splat_v8bf16(ptr %x, bfloat %y) {
810
; ZFBFMIN-ZVFBFMIN-LABEL: splat_v8bf16:
@@ -18,6 +20,12 @@ define <8 x bfloat> @splat_v8bf16(ptr %x, bfloat %y) {
1820
; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
1921
; ZVFBFMIN-NEXT: vmv.v.x v8, a0
2022
; ZVFBFMIN-NEXT: ret
23+
;
24+
; ZVFBFA-LABEL: splat_v8bf16:
25+
; ZVFBFA: # %bb.0:
26+
; ZVFBFA-NEXT: vsetvli a0, zero, e16alt, m1, ta, ma
27+
; ZVFBFA-NEXT: vfmv.v.f v8, fa0
28+
; ZVFBFA-NEXT: ret
2129
%a = insertelement <8 x bfloat> poison, bfloat %y, i32 0
2230
%b = shufflevector <8 x bfloat> %a, <8 x bfloat> poison, <8 x i32> zeroinitializer
2331
ret <8 x bfloat> %b
@@ -37,6 +45,12 @@ define <16 x bfloat> @splat_16bf16(ptr %x, bfloat %y) {
3745
; ZVFBFMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
3846
; ZVFBFMIN-NEXT: vmv.v.x v8, a0
3947
; ZVFBFMIN-NEXT: ret
48+
;
49+
; ZVFBFA-LABEL: splat_16bf16:
50+
; ZVFBFA: # %bb.0:
51+
; ZVFBFA-NEXT: vsetvli a0, zero, e16alt, m2, ta, ma
52+
; ZVFBFA-NEXT: vfmv.v.f v8, fa0
53+
; ZVFBFA-NEXT: ret
4054
%a = insertelement <16 x bfloat> poison, bfloat %y, i32 0
4155
%b = shufflevector <16 x bfloat> %a, <16 x bfloat> poison, <16 x i32> zeroinitializer
4256
ret <16 x bfloat> %b
@@ -58,6 +72,12 @@ define <64 x bfloat> @splat_64bf16(ptr %x, bfloat %y) {
5872
; ZVFBFMIN-NEXT: vsetvli zero, a1, e16, m8, ta, ma
5973
; ZVFBFMIN-NEXT: vmv.v.x v8, a0
6074
; ZVFBFMIN-NEXT: ret
75+
;
76+
; ZVFBFA-LABEL: splat_64bf16:
77+
; ZVFBFA: # %bb.0:
78+
; ZVFBFA-NEXT: vsetvli a0, zero, e16alt, m8, ta, ma
79+
; ZVFBFA-NEXT: vfmv.v.f v8, fa0
80+
; ZVFBFA-NEXT: ret
6181
%a = insertelement <64 x bfloat> poison, bfloat %y, i32 0
6282
%b = shufflevector <64 x bfloat> %a, <64 x bfloat> poison, <64 x i32> zeroinitializer
6383
ret <64 x bfloat> %b
@@ -75,6 +95,12 @@ define <8 x bfloat> @splat_zero_v8bf16(ptr %x) {
7595
; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
7696
; ZVFBFMIN-NEXT: vmv.v.i v8, 0
7797
; ZVFBFMIN-NEXT: ret
98+
;
99+
; ZVFBFA-LABEL: splat_zero_v8bf16:
100+
; ZVFBFA: # %bb.0:
101+
; ZVFBFA-NEXT: vsetvli a0, zero, e16, m1, ta, ma
102+
; ZVFBFA-NEXT: vmv.v.i v8, 0
103+
; ZVFBFA-NEXT: ret
78104
ret <8 x bfloat> splat (bfloat 0.0)
79105
}
80106

@@ -90,6 +116,12 @@ define <16 x bfloat> @splat_zero_16bf16(ptr %x) {
90116
; ZVFBFMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
91117
; ZVFBFMIN-NEXT: vmv.v.i v8, 0
92118
; ZVFBFMIN-NEXT: ret
119+
;
120+
; ZVFBFA-LABEL: splat_zero_16bf16:
121+
; ZVFBFA: # %bb.0:
122+
; ZVFBFA-NEXT: vsetvli a0, zero, e16, m2, ta, ma
123+
; ZVFBFA-NEXT: vmv.v.i v8, 0
124+
; ZVFBFA-NEXT: ret
93125
ret <16 x bfloat> splat (bfloat 0.0)
94126
}
95127

@@ -107,6 +139,13 @@ define <8 x bfloat> @splat_negzero_v8bf16(ptr %x) {
107139
; ZVFBFMIN-NEXT: vsetivli zero, 8, e16, m1, ta, ma
108140
; ZVFBFMIN-NEXT: vmv.v.x v8, a0
109141
; ZVFBFMIN-NEXT: ret
142+
;
143+
; ZVFBFA-LABEL: splat_negzero_v8bf16:
144+
; ZVFBFA: # %bb.0:
145+
; ZVFBFA-NEXT: lui a0, 1048568
146+
; ZVFBFA-NEXT: vsetvli a1, zero, e16, m1, ta, ma
147+
; ZVFBFA-NEXT: vmv.v.x v8, a0
148+
; ZVFBFA-NEXT: ret
110149
ret <8 x bfloat> splat (bfloat -0.0)
111150
}
112151

@@ -124,5 +163,12 @@ define <16 x bfloat> @splat_negzero_16bf16(ptr %x) {
124163
; ZVFBFMIN-NEXT: vsetivli zero, 16, e16, m2, ta, ma
125164
; ZVFBFMIN-NEXT: vmv.v.x v8, a0
126165
; ZVFBFMIN-NEXT: ret
166+
;
167+
; ZVFBFA-LABEL: splat_negzero_16bf16:
168+
; ZVFBFA: # %bb.0:
169+
; ZVFBFA-NEXT: lui a0, 1048568
170+
; ZVFBFA-NEXT: vsetvli a1, zero, e16, m2, ta, ma
171+
; ZVFBFA-NEXT: vmv.v.x v8, a0
172+
; ZVFBFA-NEXT: ret
127173
ret <16 x bfloat> splat (bfloat -0.0)
128174
}

0 commit comments

Comments
 (0)