Skip to content

Commit bfe0666

Browse files
[LLVM][CodeGen][SVE2] Implement nxvf64 fpround to nxvbf16. (#111012)
NOTE: This requires SVE2 (or streaming SVE) because that is where FCVTX first becomes available, and FCVTX is required to perform the necessary two-step rounding.
1 parent a649e8f commit bfe0666

File tree

6 files changed

+229
-5
lines changed

6 files changed

+229
-5
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ static bool isMergePassthruOpcode(unsigned Opc) {
268268
case AArch64ISD::FP_EXTEND_MERGE_PASSTHRU:
269269
case AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU:
270270
case AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU:
271+
case AArch64ISD::FCVTX_MERGE_PASSTHRU:
271272
case AArch64ISD::FCVTZU_MERGE_PASSTHRU:
272273
case AArch64ISD::FCVTZS_MERGE_PASSTHRU:
273274
case AArch64ISD::FSQRT_MERGE_PASSTHRU:
@@ -2652,6 +2653,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
26522653
MAKE_CASE(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU)
26532654
MAKE_CASE(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU)
26542655
MAKE_CASE(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU)
2656+
MAKE_CASE(AArch64ISD::FCVTX_MERGE_PASSTHRU)
26552657
MAKE_CASE(AArch64ISD::FCVTZU_MERGE_PASSTHRU)
26562658
MAKE_CASE(AArch64ISD::FCVTZS_MERGE_PASSTHRU)
26572659
MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU)
@@ -4416,6 +4418,19 @@ SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
44164418
// Set the quiet bit.
44174419
if (!DAG.isKnownNeverSNaN(SrcVal))
44184420
NaN = DAG.getNode(ISD::OR, DL, I32, Narrow, ImmV(0x400000));
4421+
} else if (SrcVT == MVT::nxv2f64 &&
4422+
(Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
4423+
// Round to float using round-to-odd (FCVTX) to avoid double-rounding errors and try again.
4424+
SDValue Pg = getPredicateForVector(DAG, DL, MVT::nxv2f32);
4425+
Narrow = DAG.getNode(AArch64ISD::FCVTX_MERGE_PASSTHRU, DL, MVT::nxv2f32,
4426+
Pg, SrcVal, DAG.getUNDEF(MVT::nxv2f32));
4427+
4428+
SmallVector<SDValue, 3> NewOps;
4429+
if (IsStrict)
4430+
NewOps.push_back(Op.getOperand(0));
4431+
NewOps.push_back(Narrow);
4432+
NewOps.push_back(Op.getOperand(IsStrict ? 2 : 1));
4433+
return DAG.getNode(Op.getOpcode(), DL, VT, NewOps, Op->getFlags());
44194434
} else
44204435
return SDValue();
44214436

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ enum NodeType : unsigned {
158158
FP_EXTEND_MERGE_PASSTHRU,
159159
UINT_TO_FP_MERGE_PASSTHRU,
160160
SINT_TO_FP_MERGE_PASSTHRU,
161+
FCVTX_MERGE_PASSTHRU,
161162
FCVTZU_MERGE_PASSTHRU,
162163
FCVTZS_MERGE_PASSTHRU,
163164
SIGN_EXTEND_INREG_MERGE_PASSTHRU,

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ def AArch64fcvtr_mt : SDNode<"AArch64ISD::FP_ROUND_MERGE_PASSTHRU", SDT_AArch64
357357
def AArch64fcvte_mt : SDNode<"AArch64ISD::FP_EXTEND_MERGE_PASSTHRU", SDT_AArch64FCVT>;
358358
def AArch64ucvtf_mt : SDNode<"AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU", SDT_AArch64FCVT>;
359359
def AArch64scvtf_mt : SDNode<"AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU", SDT_AArch64FCVT>;
360+
def AArch64fcvtx_mt : SDNode<"AArch64ISD::FCVTX_MERGE_PASSTHRU", SDT_AArch64FCVT>;
360361
def AArch64fcvtzu_mt : SDNode<"AArch64ISD::FCVTZU_MERGE_PASSTHRU", SDT_AArch64FCVT>;
361362
def AArch64fcvtzs_mt : SDNode<"AArch64ISD::FCVTZS_MERGE_PASSTHRU", SDT_AArch64FCVT>;
362363

@@ -3788,7 +3789,7 @@ let Predicates = [HasSVE2orSME, UseExperimentalZeroingPseudos] in {
37883789
let Predicates = [HasSVE2orSME] in {
37893790
// SVE2 floating-point convert precision
37903791
defm FCVTXNT_ZPmZ : sve2_fp_convert_down_odd_rounding_top<"fcvtxnt", "int_aarch64_sve_fcvtxnt">;
3791-
defm FCVTX_ZPmZ : sve2_fp_convert_down_odd_rounding<"fcvtx", "int_aarch64_sve_fcvtx">;
3792+
defm FCVTX_ZPmZ : sve2_fp_convert_down_odd_rounding<"fcvtx", "int_aarch64_sve_fcvtx", AArch64fcvtx_mt>;
37923793
defm FCVTNT_ZPmZ : sve2_fp_convert_down_narrow<"fcvtnt", "int_aarch64_sve_fcvtnt">;
37933794
defm FCVTLT_ZPmZ : sve2_fp_convert_up_long<"fcvtlt", "int_aarch64_sve_fcvtlt">;
37943795

llvm/lib/Target/AArch64/AArch64Subtarget.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,14 @@ class AArch64Subtarget final : public AArch64GenSubtargetInfo {
188188
(hasSMEFA64() || (!isStreaming() && !isStreamingCompatible()));
189189
}
190190

191-
/// Returns true if the target has access to either the full range of SVE instructions,
192-
/// or the streaming-compatible subset of SVE instructions.
191+
/// Returns true if the target has access to the streaming-compatible subset
192+
/// of SVE instructions.
193+
bool isStreamingSVEAvailable() const { return hasSME() && isStreaming(); }
194+
195+
/// Returns true if the target has access to either the full range of SVE
196+
/// instructions, or the streaming-compatible subset of SVE instructions.
193197
bool isSVEorStreamingSVEAvailable() const {
194-
return hasSVE() || (hasSME() && isStreaming());
198+
return hasSVE() || isStreamingSVEAvailable();
195199
}
196200

197201
unsigned getMinVectorRegisterBitWidth() const {

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3061,9 +3061,11 @@ multiclass sve2_fp_un_pred_zeroing_hsd<SDPatternOperator op> {
30613061
def : SVE_1_Op_PassthruZero_Pat<nxv2i64, op, nxv2i1, nxv2f64, !cast<Pseudo>(NAME # _D_ZERO)>;
30623062
}
30633063

3064-
multiclass sve2_fp_convert_down_odd_rounding<string asm, string op> {
3064+
multiclass sve2_fp_convert_down_odd_rounding<string asm, string op, SDPatternOperator ir_op = null_frag> {
30653065
def _DtoS : sve_fp_2op_p_zd<0b0001010, asm, ZPR64, ZPR32, ElementSizeD>;
3066+
30663067
def : SVE_3_Op_Pat<nxv4f32, !cast<SDPatternOperator>(op # _f32f64), nxv4f32, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _DtoS)>;
3068+
def : SVE_1_Op_Passthru_Pat<nxv2f32, ir_op, nxv2i1, nxv2f64, !cast<Instruction>(NAME # _DtoS)>;
30673069
}
30683070

30693071
//===----------------------------------------------------------------------===//
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mattr=+sve2 < %s | FileCheck %s --check-prefixes=NOBF16
3+
; RUN: llc -mattr=+sve2 --enable-no-nans-fp-math < %s | FileCheck %s --check-prefixes=NOBF16NNAN
4+
; RUN: llc -mattr=+sve2,+bf16 < %s | FileCheck %s --check-prefixes=BF16
5+
; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s --check-prefixes=BF16
6+
7+
target triple = "aarch64-unknown-linux-gnu"
8+
9+
define <vscale x 2 x bfloat> @fptrunc_nxv2f64_to_nxv2bf16(<vscale x 2 x double> %a) {
10+
; NOBF16-LABEL: fptrunc_nxv2f64_to_nxv2bf16:
11+
; NOBF16: // %bb.0:
12+
; NOBF16-NEXT: ptrue p0.d
13+
; NOBF16-NEXT: mov z1.s, #32767 // =0x7fff
14+
; NOBF16-NEXT: fcvtx z0.s, p0/m, z0.d
15+
; NOBF16-NEXT: lsr z2.s, z0.s, #16
16+
; NOBF16-NEXT: add z1.s, z0.s, z1.s
17+
; NOBF16-NEXT: fcmuo p0.s, p0/z, z0.s, z0.s
18+
; NOBF16-NEXT: orr z0.s, z0.s, #0x400000
19+
; NOBF16-NEXT: and z2.s, z2.s, #0x1
20+
; NOBF16-NEXT: add z1.s, z2.s, z1.s
21+
; NOBF16-NEXT: sel z0.s, p0, z0.s, z1.s
22+
; NOBF16-NEXT: lsr z0.s, z0.s, #16
23+
; NOBF16-NEXT: ret
24+
;
25+
; NOBF16NNAN-LABEL: fptrunc_nxv2f64_to_nxv2bf16:
26+
; NOBF16NNAN: // %bb.0:
27+
; NOBF16NNAN-NEXT: ptrue p0.d
28+
; NOBF16NNAN-NEXT: mov z1.s, #32767 // =0x7fff
29+
; NOBF16NNAN-NEXT: fcvtx z0.s, p0/m, z0.d
30+
; NOBF16NNAN-NEXT: lsr z2.s, z0.s, #16
31+
; NOBF16NNAN-NEXT: add z0.s, z0.s, z1.s
32+
; NOBF16NNAN-NEXT: and z2.s, z2.s, #0x1
33+
; NOBF16NNAN-NEXT: add z0.s, z2.s, z0.s
34+
; NOBF16NNAN-NEXT: lsr z0.s, z0.s, #16
35+
; NOBF16NNAN-NEXT: ret
36+
;
37+
; BF16-LABEL: fptrunc_nxv2f64_to_nxv2bf16:
38+
; BF16: // %bb.0:
39+
; BF16-NEXT: ptrue p0.d
40+
; BF16-NEXT: fcvtx z0.s, p0/m, z0.d
41+
; BF16-NEXT: bfcvt z0.h, p0/m, z0.s
42+
; BF16-NEXT: ret
43+
%res = fptrunc <vscale x 2 x double> %a to <vscale x 2 x bfloat>
44+
ret <vscale x 2 x bfloat> %res
45+
}
46+
47+
define <vscale x 4 x bfloat> @fptrunc_nxv4f64_to_nxv4bf16(<vscale x 4 x double> %a) {
48+
; NOBF16-LABEL: fptrunc_nxv4f64_to_nxv4bf16:
49+
; NOBF16: // %bb.0:
50+
; NOBF16-NEXT: ptrue p0.d
51+
; NOBF16-NEXT: mov z2.s, #32767 // =0x7fff
52+
; NOBF16-NEXT: fcvtx z1.s, p0/m, z1.d
53+
; NOBF16-NEXT: fcvtx z0.s, p0/m, z0.d
54+
; NOBF16-NEXT: lsr z3.s, z1.s, #16
55+
; NOBF16-NEXT: lsr z4.s, z0.s, #16
56+
; NOBF16-NEXT: add z5.s, z1.s, z2.s
57+
; NOBF16-NEXT: add z2.s, z0.s, z2.s
58+
; NOBF16-NEXT: fcmuo p1.s, p0/z, z1.s, z1.s
59+
; NOBF16-NEXT: fcmuo p0.s, p0/z, z0.s, z0.s
60+
; NOBF16-NEXT: orr z1.s, z1.s, #0x400000
61+
; NOBF16-NEXT: orr z0.s, z0.s, #0x400000
62+
; NOBF16-NEXT: and z3.s, z3.s, #0x1
63+
; NOBF16-NEXT: and z4.s, z4.s, #0x1
64+
; NOBF16-NEXT: add z3.s, z3.s, z5.s
65+
; NOBF16-NEXT: add z2.s, z4.s, z2.s
66+
; NOBF16-NEXT: sel z1.s, p1, z1.s, z3.s
67+
; NOBF16-NEXT: sel z0.s, p0, z0.s, z2.s
68+
; NOBF16-NEXT: lsr z1.s, z1.s, #16
69+
; NOBF16-NEXT: lsr z0.s, z0.s, #16
70+
; NOBF16-NEXT: uzp1 z0.s, z0.s, z1.s
71+
; NOBF16-NEXT: ret
72+
;
73+
; NOBF16NNAN-LABEL: fptrunc_nxv4f64_to_nxv4bf16:
74+
; NOBF16NNAN: // %bb.0:
75+
; NOBF16NNAN-NEXT: ptrue p0.d
76+
; NOBF16NNAN-NEXT: mov z2.s, #32767 // =0x7fff
77+
; NOBF16NNAN-NEXT: fcvtx z1.s, p0/m, z1.d
78+
; NOBF16NNAN-NEXT: fcvtx z0.s, p0/m, z0.d
79+
; NOBF16NNAN-NEXT: lsr z3.s, z1.s, #16
80+
; NOBF16NNAN-NEXT: lsr z4.s, z0.s, #16
81+
; NOBF16NNAN-NEXT: add z1.s, z1.s, z2.s
82+
; NOBF16NNAN-NEXT: add z0.s, z0.s, z2.s
83+
; NOBF16NNAN-NEXT: and z3.s, z3.s, #0x1
84+
; NOBF16NNAN-NEXT: and z4.s, z4.s, #0x1
85+
; NOBF16NNAN-NEXT: add z1.s, z3.s, z1.s
86+
; NOBF16NNAN-NEXT: add z0.s, z4.s, z0.s
87+
; NOBF16NNAN-NEXT: lsr z1.s, z1.s, #16
88+
; NOBF16NNAN-NEXT: lsr z0.s, z0.s, #16
89+
; NOBF16NNAN-NEXT: uzp1 z0.s, z0.s, z1.s
90+
; NOBF16NNAN-NEXT: ret
91+
;
92+
; BF16-LABEL: fptrunc_nxv4f64_to_nxv4bf16:
93+
; BF16: // %bb.0:
94+
; BF16-NEXT: ptrue p0.d
95+
; BF16-NEXT: fcvtx z1.s, p0/m, z1.d
96+
; BF16-NEXT: fcvtx z0.s, p0/m, z0.d
97+
; BF16-NEXT: bfcvt z1.h, p0/m, z1.s
98+
; BF16-NEXT: bfcvt z0.h, p0/m, z0.s
99+
; BF16-NEXT: uzp1 z0.s, z0.s, z1.s
100+
; BF16-NEXT: ret
101+
%res = fptrunc <vscale x 4 x double> %a to <vscale x 4 x bfloat>
102+
ret <vscale x 4 x bfloat> %res
103+
}
104+
105+
define <vscale x 8 x bfloat> @fptrunc_nxv8f64_to_nxv8bf16(<vscale x 8 x double> %a) {
106+
; NOBF16-LABEL: fptrunc_nxv8f64_to_nxv8bf16:
107+
; NOBF16: // %bb.0:
108+
; NOBF16-NEXT: ptrue p0.d
109+
; NOBF16-NEXT: mov z4.s, #32767 // =0x7fff
110+
; NOBF16-NEXT: fcvtx z3.s, p0/m, z3.d
111+
; NOBF16-NEXT: fcvtx z2.s, p0/m, z2.d
112+
; NOBF16-NEXT: fcvtx z1.s, p0/m, z1.d
113+
; NOBF16-NEXT: fcvtx z0.s, p0/m, z0.d
114+
; NOBF16-NEXT: lsr z5.s, z3.s, #16
115+
; NOBF16-NEXT: lsr z6.s, z2.s, #16
116+
; NOBF16-NEXT: lsr z7.s, z1.s, #16
117+
; NOBF16-NEXT: lsr z24.s, z0.s, #16
118+
; NOBF16-NEXT: add z25.s, z3.s, z4.s
119+
; NOBF16-NEXT: add z26.s, z2.s, z4.s
120+
; NOBF16-NEXT: add z27.s, z1.s, z4.s
121+
; NOBF16-NEXT: add z4.s, z0.s, z4.s
122+
; NOBF16-NEXT: fcmuo p1.s, p0/z, z3.s, z3.s
123+
; NOBF16-NEXT: and z5.s, z5.s, #0x1
124+
; NOBF16-NEXT: and z6.s, z6.s, #0x1
125+
; NOBF16-NEXT: and z7.s, z7.s, #0x1
126+
; NOBF16-NEXT: and z24.s, z24.s, #0x1
127+
; NOBF16-NEXT: fcmuo p2.s, p0/z, z2.s, z2.s
128+
; NOBF16-NEXT: fcmuo p3.s, p0/z, z1.s, z1.s
129+
; NOBF16-NEXT: fcmuo p0.s, p0/z, z0.s, z0.s
130+
; NOBF16-NEXT: orr z3.s, z3.s, #0x400000
131+
; NOBF16-NEXT: orr z2.s, z2.s, #0x400000
132+
; NOBF16-NEXT: add z5.s, z5.s, z25.s
133+
; NOBF16-NEXT: add z6.s, z6.s, z26.s
134+
; NOBF16-NEXT: add z7.s, z7.s, z27.s
135+
; NOBF16-NEXT: add z4.s, z24.s, z4.s
136+
; NOBF16-NEXT: orr z1.s, z1.s, #0x400000
137+
; NOBF16-NEXT: orr z0.s, z0.s, #0x400000
138+
; NOBF16-NEXT: sel z3.s, p1, z3.s, z5.s
139+
; NOBF16-NEXT: sel z2.s, p2, z2.s, z6.s
140+
; NOBF16-NEXT: sel z1.s, p3, z1.s, z7.s
141+
; NOBF16-NEXT: sel z0.s, p0, z0.s, z4.s
142+
; NOBF16-NEXT: lsr z3.s, z3.s, #16
143+
; NOBF16-NEXT: lsr z2.s, z2.s, #16
144+
; NOBF16-NEXT: lsr z1.s, z1.s, #16
145+
; NOBF16-NEXT: lsr z0.s, z0.s, #16
146+
; NOBF16-NEXT: uzp1 z2.s, z2.s, z3.s
147+
; NOBF16-NEXT: uzp1 z0.s, z0.s, z1.s
148+
; NOBF16-NEXT: uzp1 z0.h, z0.h, z2.h
149+
; NOBF16-NEXT: ret
150+
;
151+
; NOBF16NNAN-LABEL: fptrunc_nxv8f64_to_nxv8bf16:
152+
; NOBF16NNAN: // %bb.0:
153+
; NOBF16NNAN-NEXT: ptrue p0.d
154+
; NOBF16NNAN-NEXT: mov z4.s, #32767 // =0x7fff
155+
; NOBF16NNAN-NEXT: fcvtx z3.s, p0/m, z3.d
156+
; NOBF16NNAN-NEXT: fcvtx z2.s, p0/m, z2.d
157+
; NOBF16NNAN-NEXT: fcvtx z1.s, p0/m, z1.d
158+
; NOBF16NNAN-NEXT: fcvtx z0.s, p0/m, z0.d
159+
; NOBF16NNAN-NEXT: lsr z5.s, z3.s, #16
160+
; NOBF16NNAN-NEXT: lsr z6.s, z2.s, #16
161+
; NOBF16NNAN-NEXT: lsr z7.s, z1.s, #16
162+
; NOBF16NNAN-NEXT: lsr z24.s, z0.s, #16
163+
; NOBF16NNAN-NEXT: add z3.s, z3.s, z4.s
164+
; NOBF16NNAN-NEXT: add z2.s, z2.s, z4.s
165+
; NOBF16NNAN-NEXT: add z1.s, z1.s, z4.s
166+
; NOBF16NNAN-NEXT: add z0.s, z0.s, z4.s
167+
; NOBF16NNAN-NEXT: and z5.s, z5.s, #0x1
168+
; NOBF16NNAN-NEXT: and z6.s, z6.s, #0x1
169+
; NOBF16NNAN-NEXT: and z7.s, z7.s, #0x1
170+
; NOBF16NNAN-NEXT: and z24.s, z24.s, #0x1
171+
; NOBF16NNAN-NEXT: add z3.s, z5.s, z3.s
172+
; NOBF16NNAN-NEXT: add z2.s, z6.s, z2.s
173+
; NOBF16NNAN-NEXT: add z1.s, z7.s, z1.s
174+
; NOBF16NNAN-NEXT: add z0.s, z24.s, z0.s
175+
; NOBF16NNAN-NEXT: lsr z3.s, z3.s, #16
176+
; NOBF16NNAN-NEXT: lsr z2.s, z2.s, #16
177+
; NOBF16NNAN-NEXT: lsr z1.s, z1.s, #16
178+
; NOBF16NNAN-NEXT: lsr z0.s, z0.s, #16
179+
; NOBF16NNAN-NEXT: uzp1 z2.s, z2.s, z3.s
180+
; NOBF16NNAN-NEXT: uzp1 z0.s, z0.s, z1.s
181+
; NOBF16NNAN-NEXT: uzp1 z0.h, z0.h, z2.h
182+
; NOBF16NNAN-NEXT: ret
183+
;
184+
; BF16-LABEL: fptrunc_nxv8f64_to_nxv8bf16:
185+
; BF16: // %bb.0:
186+
; BF16-NEXT: ptrue p0.d
187+
; BF16-NEXT: fcvtx z3.s, p0/m, z3.d
188+
; BF16-NEXT: fcvtx z2.s, p0/m, z2.d
189+
; BF16-NEXT: fcvtx z1.s, p0/m, z1.d
190+
; BF16-NEXT: fcvtx z0.s, p0/m, z0.d
191+
; BF16-NEXT: bfcvt z3.h, p0/m, z3.s
192+
; BF16-NEXT: bfcvt z2.h, p0/m, z2.s
193+
; BF16-NEXT: bfcvt z1.h, p0/m, z1.s
194+
; BF16-NEXT: bfcvt z0.h, p0/m, z0.s
195+
; BF16-NEXT: uzp1 z2.s, z2.s, z3.s
196+
; BF16-NEXT: uzp1 z0.s, z0.s, z1.s
197+
; BF16-NEXT: uzp1 z0.h, z0.h, z2.h
198+
; BF16-NEXT: ret
199+
%res = fptrunc <vscale x 8 x double> %a to <vscale x 8 x bfloat>
200+
ret <vscale x 8 x bfloat> %res
201+
}

0 commit comments

Comments
 (0)