Skip to content

Commit cc9d3f2

Browse files
[AArch64][GlobalISel] Improve lowering of vector fp16 fptrunc (#163398)
This commit improves the lowering of vectors of fp16 when truncating and (previously) extending. Truncating has to be handled in a specific way to avoid double rounding.
1 parent 1cea4a0 commit cc9d3f2

File tree

8 files changed

+170
-186
lines changed

8 files changed

+170
-186
lines changed

llvm/lib/Target/AArch64/AArch64InstrGISel.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,13 @@ def G_VLSHR : AArch64GenericInstruction {
149149
let hasSideEffects = 0;
150150
}
151151

152+
// Float truncation using round to odd.
// Generic AArch64 opcode with one result ($dst, type0) and one source
// ($src, type1) and no side effects. The legalizer emits it as the first
// (v2s64 -> v2s32) step when narrowing f64 vectors to f16, so that the
// following f32 -> f16 truncation does not round a second time in a way that
// changes the result (double rounding).
153+
def G_FPTRUNC_ODD : AArch64GenericInstruction {
154+
let OutOperandList = (outs type0:$dst);
155+
let InOperandList = (ins type1:$src);
156+
let hasSideEffects = false;
157+
}
158+
152159
// Represents an integer to FP conversion on the FPR bank.
153160
def G_SITOF : AArch64GenericInstruction {
154161
let OutOperandList = (outs type0:$dst);
@@ -297,6 +304,8 @@ def : GINodeEquiv<G_EXTRACT_VECTOR_ELT, vector_extract>;
297304

298305
def : GINodeEquiv<G_AARCH64_PREFETCH, AArch64Prefetch>;
299306

307+
def : GINodeEquiv<G_FPTRUNC_ODD, AArch64fcvtxn_n>;
308+
300309
// These are patterns that we only use for GlobalISel via the importer.
301310
def : Pat<(f32 (fadd (vector_extract (v2f32 FPR64:$Rn), (i64 0)),
302311
(vector_extract (v2f32 FPR64:$Rn), (i64 1)))),

llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
2222
#include "llvm/CodeGen/GlobalISel/Utils.h"
2323
#include "llvm/CodeGen/MachineInstr.h"
24+
#include "llvm/CodeGen/MachineInstrBuilder.h"
2425
#include "llvm/CodeGen/MachineRegisterInfo.h"
2526
#include "llvm/CodeGen/TargetOpcodes.h"
2627
#include "llvm/IR/DerivedTypes.h"
@@ -820,8 +821,17 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
820821
.legalFor(
821822
{{s16, s32}, {s16, s64}, {s32, s64}, {v4s16, v4s32}, {v2s32, v2s64}})
822823
.libcallFor({{s16, s128}, {s32, s128}, {s64, s128}})
823-
.clampNumElements(0, v4s16, v4s16)
824-
.clampNumElements(0, v2s32, v2s32)
824+
.moreElementsToNextPow2(1)
825+
.customIf([](const LegalityQuery &Q) {
826+
LLT DstTy = Q.Types[0];
827+
LLT SrcTy = Q.Types[1];
828+
return SrcTy.isFixedVector() && DstTy.isFixedVector() &&
829+
SrcTy.getScalarSizeInBits() == 64 &&
830+
DstTy.getScalarSizeInBits() == 16;
831+
})
832+
// Clamp based on input
833+
.clampNumElements(1, v4s32, v4s32)
834+
.clampNumElements(1, v2s64, v2s64)
825835
.scalarize(0);
826836

827837
getActionDefinitionsBuilder(G_FPEXT)
@@ -1479,6 +1489,10 @@ bool AArch64LegalizerInfo::legalizeCustom(
14791489
return legalizeICMP(MI, MRI, MIRBuilder);
14801490
case TargetOpcode::G_BITCAST:
14811491
return legalizeBitcast(MI, Helper);
1492+
case TargetOpcode::G_FPTRUNC:
1493+
// In order to lower f64 to f16 properly, we need to use f32 as an
1494+
// intermediary
1495+
return legalizeFptrunc(MI, MIRBuilder, MRI);
14821496
}
14831497

14841498
llvm_unreachable("expected switch to return");
@@ -2416,3 +2430,80 @@ bool AArch64LegalizerInfo::legalizePrefetch(MachineInstr &MI,
24162430
MI.eraseFromParent();
24172431
return true;
24182432
}
2433+
2434+
// Custom legalization of a fixed-vector G_FPTRUNC in two steps:
//   1. each 2-element piece of the source is converted to v2s32 with
//      G_FPTRUNC_ODD (round to odd), and
//   2. the s32 results are truncated to s16 with an ordinary G_FPTRUNC, four
//      elements at a time when the element count is a multiple of 4, otherwise
//      two at a time.
// The s16 pieces are then merged (if there is more than one) and every use of
// the original destination register is rewritten to the new value.
//
// \param MI         the G_FPTRUNC being legalized; it is erased on return.
// \param MIRBuilder builder used to emit the replacement instructions.
// \param MRI        used to create virtual registers and replace Dst.
// \returns true on every path.
//
// NOTE(review): the pieces are created as v2s64 and the <= 2 element case
// feeds Src straight into a v2s32-producing G_FPTRUNC_ODD, so this presumably
// expects 64-bit source elements - confirm against the customIf predicate.
bool AArch64LegalizerInfo::legalizeFptrunc(MachineInstr &MI,
2435+
MachineIRBuilder &MIRBuilder,
2436+
MachineRegisterInfo &MRI) const {
2437+
auto [Dst, DstTy, Src, SrcTy] = MI.getFirst2RegLLTs();
2438+
assert(SrcTy.isFixedVector() && isPowerOf2_32(SrcTy.getNumElements()) &&
2439+
"Expected a power of 2 elements");
2440+
2441+
LLT s16 = LLT::scalar(16);
2442+
LLT s32 = LLT::scalar(32);
2443+
LLT s64 = LLT::scalar(64);
2444+
LLT v2s16 = LLT::fixed_vector(2, s16);
2445+
LLT v4s16 = LLT::fixed_vector(4, s16);
2446+
LLT v2s32 = LLT::fixed_vector(2, s32);
2447+
LLT v4s32 = LLT::fixed_vector(4, s32);
2448+
LLT v2s64 = LLT::fixed_vector(2, s64);
2449+
2450+
// RegsToUnmergeTo: 2-element source pieces fed to G_FPTRUNC_ODD.
// TruncOddDstRegs: the v2s32 results of those round-to-odd truncations.
// RegsToMerge:     the final s16 vector pieces that make up the result.
SmallVector<Register> RegsToUnmergeTo;
2451+
SmallVector<Register> TruncOddDstRegs;
2452+
SmallVector<Register> RegsToMerge;
2453+
2454+
unsigned ElemCount = SrcTy.getNumElements();
2455+
2456+
// Find the biggest size chunks we can work with
// (4 when ElemCount is a multiple of 4, otherwise 2).
2457+
int StepSize = ElemCount % 4 ? 2 : 4;
2458+
2459+
// If we have a power of 2 greater than 2, we need to first unmerge into
2460+
// enough pieces
2461+
if (ElemCount <= 2)
2462+
RegsToUnmergeTo.push_back(Src);
2463+
else {
2464+
for (unsigned i = 0; i < ElemCount / 2; ++i)
2465+
RegsToUnmergeTo.push_back(MRI.createGenericVirtualRegister(v2s64));
2466+
2467+
MIRBuilder.buildUnmerge(RegsToUnmergeTo, Src);
2468+
}
2469+
2470+
// Create all of the round-to-odd instructions and store them
2471+
for (auto SrcReg : RegsToUnmergeTo) {
2472+
Register Mid =
2473+
MIRBuilder.buildInstr(AArch64::G_FPTRUNC_ODD, {v2s32}, {SrcReg})
2474+
.getReg(0);
2475+
TruncOddDstRegs.push_back(Mid);
2476+
}
2477+
2478+
// Truncate 4s32 to 4s16 if we can to reduce instruction count, otherwise
2479+
// truncate 2s32 to 2s16.
2480+
unsigned Index = 0;
2481+
for (unsigned LoopIter = 0; LoopIter < ElemCount / StepSize; ++LoopIter) {
2482+
if (StepSize == 4) {
2483+
// Concatenate two consecutive v2s32 results into one v4s32. The two
// Index++ appear in a braced-init-list, whose elements are evaluated left
// to right, so the lower-indexed register becomes the first operand.
Register ConcatDst =
2484+
MIRBuilder
2485+
.buildMergeLikeInstr(
2486+
{v4s32}, {TruncOddDstRegs[Index++], TruncOddDstRegs[Index++]})
2487+
.getReg(0);
2488+
2489+
RegsToMerge.push_back(
2490+
MIRBuilder.buildFPTrunc(v4s16, ConcatDst).getReg(0));
2491+
} else {
2492+
RegsToMerge.push_back(
2493+
MIRBuilder.buildFPTrunc(v2s16, TruncOddDstRegs[Index++]).getReg(0));
2494+
}
2495+
}
2496+
2497+
// If there is only one register, replace the destination
2498+
if (RegsToMerge.size() == 1) {
2499+
MRI.replaceRegWith(Dst, RegsToMerge.pop_back_val());
2500+
MI.eraseFromParent();
2501+
return true;
2502+
}
2503+
2504+
// Merge the rest of the instructions & replace the register
2505+
Register Fin = MIRBuilder.buildMergeLikeInstr(DstTy, RegsToMerge).getReg(0);
2506+
MRI.replaceRegWith(Dst, Fin);
2507+
MI.eraseFromParent();
2508+
return true;
2509+
}

llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class AArch64LegalizerInfo : public LegalizerInfo {
6767
bool legalizeDynStackAlloc(MachineInstr &MI, LegalizerHelper &Helper) const;
6868
bool legalizePrefetch(MachineInstr &MI, LegalizerHelper &Helper) const;
6969
bool legalizeBitcast(MachineInstr &MI, LegalizerHelper &Helper) const;
70+
bool legalizeFptrunc(MachineInstr &MI, MachineIRBuilder &MIRBuilder,
71+
MachineRegisterInfo &MRI) const;
7072
const AArch64Subtarget *ST;
7173
};
7274
} // End llvm namespace.

llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,8 +578,8 @@
578578
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
579579
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
580580
# DEBUG-NEXT: G_FPTRUNC (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
581-
# DEBUG-NEXT: .. the first uncovered type index: 2, OK
582-
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
581+
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
582+
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
583583
# DEBUG-NEXT: G_FPTOSI (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
584584
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
585585
# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected

llvm/test/CodeGen/AArch64/arm64-fp128.ll

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,30 +1197,22 @@ define <2 x half> @vec_round_f16(<2 x fp128> %val) {
11971197
;
11981198
; CHECK-GI-LABEL: vec_round_f16:
11991199
; CHECK-GI: // %bb.0:
1200-
; CHECK-GI-NEXT: sub sp, sp, #64
1201-
; CHECK-GI-NEXT: str x30, [sp, #48] // 8-byte Spill
1202-
; CHECK-GI-NEXT: .cfi_def_cfa_offset 64
1200+
; CHECK-GI-NEXT: sub sp, sp, #48
1201+
; CHECK-GI-NEXT: str x30, [sp, #32] // 8-byte Spill
1202+
; CHECK-GI-NEXT: .cfi_def_cfa_offset 48
12031203
; CHECK-GI-NEXT: .cfi_offset w30, -16
1204-
; CHECK-GI-NEXT: mov v2.d[0], x8
12051204
; CHECK-GI-NEXT: str q1, [sp] // 16-byte Spill
1206-
; CHECK-GI-NEXT: mov v2.d[1], x8
1207-
; CHECK-GI-NEXT: str q2, [sp, #32] // 16-byte Spill
12081205
; CHECK-GI-NEXT: bl __trunctfhf2
12091206
; CHECK-GI-NEXT: // kill: def $h0 killed $h0 def $q0
12101207
; CHECK-GI-NEXT: str q0, [sp, #16] // 16-byte Spill
12111208
; CHECK-GI-NEXT: ldr q0, [sp] // 16-byte Reload
12121209
; CHECK-GI-NEXT: bl __trunctfhf2
1210+
; CHECK-GI-NEXT: ldr q1, [sp, #16] // 16-byte Reload
12131211
; CHECK-GI-NEXT: // kill: def $h0 killed $h0 def $q0
1214-
; CHECK-GI-NEXT: str q0, [sp] // 16-byte Spill
1215-
; CHECK-GI-NEXT: ldr q0, [sp, #32] // 16-byte Reload
1216-
; CHECK-GI-NEXT: bl __trunctfhf2
1217-
; CHECK-GI-NEXT: ldr q0, [sp, #32] // 16-byte Reload
1218-
; CHECK-GI-NEXT: bl __trunctfhf2
1219-
; CHECK-GI-NEXT: ldp q1, q0, [sp] // 32-byte Folded Reload
1220-
; CHECK-GI-NEXT: ldr x30, [sp, #48] // 8-byte Reload
1221-
; CHECK-GI-NEXT: mov v0.h[1], v1.h[0]
1222-
; CHECK-GI-NEXT: // kill: def $d0 killed $d0 killed $q0
1223-
; CHECK-GI-NEXT: add sp, sp, #64
1212+
; CHECK-GI-NEXT: ldr x30, [sp, #32] // 8-byte Reload
1213+
; CHECK-GI-NEXT: mov v1.h[1], v0.h[0]
1214+
; CHECK-GI-NEXT: fmov d0, d1
1215+
; CHECK-GI-NEXT: add sp, sp, #48
12241216
; CHECK-GI-NEXT: ret
12251217
%dst = fptrunc <2 x fp128> %val to <2 x half>
12261218
ret <2 x half> %dst

llvm/test/CodeGen/AArch64/fp16-v4-instructions.ll

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -170,47 +170,12 @@ define <4 x half> @s_to_h(<4 x float> %a) {
170170
}
171171

172172
define <4 x half> @d_to_h(<4 x double> %a) {
173-
; CHECK-CVT-SD-LABEL: d_to_h:
174-
; CHECK-CVT-SD: // %bb.0:
175-
; CHECK-CVT-SD-NEXT: fcvtxn v0.2s, v0.2d
176-
; CHECK-CVT-SD-NEXT: fcvtxn2 v0.4s, v1.2d
177-
; CHECK-CVT-SD-NEXT: fcvtn v0.4h, v0.4s
178-
; CHECK-CVT-SD-NEXT: ret
179-
;
180-
; CHECK-FP16-SD-LABEL: d_to_h:
181-
; CHECK-FP16-SD: // %bb.0:
182-
; CHECK-FP16-SD-NEXT: fcvtxn v0.2s, v0.2d
183-
; CHECK-FP16-SD-NEXT: fcvtxn2 v0.4s, v1.2d
184-
; CHECK-FP16-SD-NEXT: fcvtn v0.4h, v0.4s
185-
; CHECK-FP16-SD-NEXT: ret
186-
;
187-
; CHECK-CVT-GI-LABEL: d_to_h:
188-
; CHECK-CVT-GI: // %bb.0:
189-
; CHECK-CVT-GI-NEXT: mov d2, v0.d[1]
190-
; CHECK-CVT-GI-NEXT: fcvt h0, d0
191-
; CHECK-CVT-GI-NEXT: mov d3, v1.d[1]
192-
; CHECK-CVT-GI-NEXT: fcvt h1, d1
193-
; CHECK-CVT-GI-NEXT: fcvt h2, d2
194-
; CHECK-CVT-GI-NEXT: mov v0.h[1], v2.h[0]
195-
; CHECK-CVT-GI-NEXT: fcvt h2, d3
196-
; CHECK-CVT-GI-NEXT: mov v0.h[2], v1.h[0]
197-
; CHECK-CVT-GI-NEXT: mov v0.h[3], v2.h[0]
198-
; CHECK-CVT-GI-NEXT: // kill: def $d0 killed $d0 killed $q0
199-
; CHECK-CVT-GI-NEXT: ret
200-
;
201-
; CHECK-FP16-GI-LABEL: d_to_h:
202-
; CHECK-FP16-GI: // %bb.0:
203-
; CHECK-FP16-GI-NEXT: mov d2, v0.d[1]
204-
; CHECK-FP16-GI-NEXT: fcvt h0, d0
205-
; CHECK-FP16-GI-NEXT: mov d3, v1.d[1]
206-
; CHECK-FP16-GI-NEXT: fcvt h1, d1
207-
; CHECK-FP16-GI-NEXT: fcvt h2, d2
208-
; CHECK-FP16-GI-NEXT: mov v0.h[1], v2.h[0]
209-
; CHECK-FP16-GI-NEXT: fcvt h2, d3
210-
; CHECK-FP16-GI-NEXT: mov v0.h[2], v1.h[0]
211-
; CHECK-FP16-GI-NEXT: mov v0.h[3], v2.h[0]
212-
; CHECK-FP16-GI-NEXT: // kill: def $d0 killed $d0 killed $q0
213-
; CHECK-FP16-GI-NEXT: ret
173+
; CHECK-LABEL: d_to_h:
174+
; CHECK: // %bb.0:
175+
; CHECK-NEXT: fcvtxn v0.2s, v0.2d
176+
; CHECK-NEXT: fcvtxn2 v0.4s, v1.2d
177+
; CHECK-NEXT: fcvtn v0.4h, v0.4s
178+
; CHECK-NEXT: ret
214179
%1 = fptrunc <4 x double> %a to <4 x half>
215180
ret <4 x half> %1
216181
}

llvm/test/CodeGen/AArch64/fp16-v8-instructions.ll

Lines changed: 9 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -176,71 +176,15 @@ define <8 x half> @s_to_h(<8 x float> %a) {
176176
}
177177

178178
define <8 x half> @d_to_h(<8 x double> %a) {
179-
; CHECK-CVT-SD-LABEL: d_to_h:
180-
; CHECK-CVT-SD: // %bb.0:
181-
; CHECK-CVT-SD-NEXT: fcvtxn v0.2s, v0.2d
182-
; CHECK-CVT-SD-NEXT: fcvtxn v2.2s, v2.2d
183-
; CHECK-CVT-SD-NEXT: fcvtxn2 v0.4s, v1.2d
184-
; CHECK-CVT-SD-NEXT: fcvtxn2 v2.4s, v3.2d
185-
; CHECK-CVT-SD-NEXT: fcvtn v0.4h, v0.4s
186-
; CHECK-CVT-SD-NEXT: fcvtn2 v0.8h, v2.4s
187-
; CHECK-CVT-SD-NEXT: ret
188-
;
189-
; CHECK-FP16-SD-LABEL: d_to_h:
190-
; CHECK-FP16-SD: // %bb.0:
191-
; CHECK-FP16-SD-NEXT: fcvtxn v0.2s, v0.2d
192-
; CHECK-FP16-SD-NEXT: fcvtxn v2.2s, v2.2d
193-
; CHECK-FP16-SD-NEXT: fcvtxn2 v0.4s, v1.2d
194-
; CHECK-FP16-SD-NEXT: fcvtxn2 v2.4s, v3.2d
195-
; CHECK-FP16-SD-NEXT: fcvtn v0.4h, v0.4s
196-
; CHECK-FP16-SD-NEXT: fcvtn2 v0.8h, v2.4s
197-
; CHECK-FP16-SD-NEXT: ret
198-
;
199-
; CHECK-CVT-GI-LABEL: d_to_h:
200-
; CHECK-CVT-GI: // %bb.0:
201-
; CHECK-CVT-GI-NEXT: mov d4, v0.d[1]
202-
; CHECK-CVT-GI-NEXT: fcvt h0, d0
203-
; CHECK-CVT-GI-NEXT: mov d5, v1.d[1]
204-
; CHECK-CVT-GI-NEXT: fcvt h1, d1
205-
; CHECK-CVT-GI-NEXT: fcvt h4, d4
206-
; CHECK-CVT-GI-NEXT: mov v0.h[1], v4.h[0]
207-
; CHECK-CVT-GI-NEXT: fcvt h4, d5
208-
; CHECK-CVT-GI-NEXT: mov v0.h[2], v1.h[0]
209-
; CHECK-CVT-GI-NEXT: mov d1, v2.d[1]
210-
; CHECK-CVT-GI-NEXT: fcvt h2, d2
211-
; CHECK-CVT-GI-NEXT: mov v0.h[3], v4.h[0]
212-
; CHECK-CVT-GI-NEXT: fcvt h1, d1
213-
; CHECK-CVT-GI-NEXT: mov v0.h[4], v2.h[0]
214-
; CHECK-CVT-GI-NEXT: mov d2, v3.d[1]
215-
; CHECK-CVT-GI-NEXT: fcvt h3, d3
216-
; CHECK-CVT-GI-NEXT: mov v0.h[5], v1.h[0]
217-
; CHECK-CVT-GI-NEXT: fcvt h1, d2
218-
; CHECK-CVT-GI-NEXT: mov v0.h[6], v3.h[0]
219-
; CHECK-CVT-GI-NEXT: mov v0.h[7], v1.h[0]
220-
; CHECK-CVT-GI-NEXT: ret
221-
;
222-
; CHECK-FP16-GI-LABEL: d_to_h:
223-
; CHECK-FP16-GI: // %bb.0:
224-
; CHECK-FP16-GI-NEXT: mov d4, v0.d[1]
225-
; CHECK-FP16-GI-NEXT: fcvt h0, d0
226-
; CHECK-FP16-GI-NEXT: mov d5, v1.d[1]
227-
; CHECK-FP16-GI-NEXT: fcvt h1, d1
228-
; CHECK-FP16-GI-NEXT: fcvt h4, d4
229-
; CHECK-FP16-GI-NEXT: mov v0.h[1], v4.h[0]
230-
; CHECK-FP16-GI-NEXT: fcvt h4, d5
231-
; CHECK-FP16-GI-NEXT: mov v0.h[2], v1.h[0]
232-
; CHECK-FP16-GI-NEXT: mov d1, v2.d[1]
233-
; CHECK-FP16-GI-NEXT: fcvt h2, d2
234-
; CHECK-FP16-GI-NEXT: mov v0.h[3], v4.h[0]
235-
; CHECK-FP16-GI-NEXT: fcvt h1, d1
236-
; CHECK-FP16-GI-NEXT: mov v0.h[4], v2.h[0]
237-
; CHECK-FP16-GI-NEXT: mov d2, v3.d[1]
238-
; CHECK-FP16-GI-NEXT: fcvt h3, d3
239-
; CHECK-FP16-GI-NEXT: mov v0.h[5], v1.h[0]
240-
; CHECK-FP16-GI-NEXT: fcvt h1, d2
241-
; CHECK-FP16-GI-NEXT: mov v0.h[6], v3.h[0]
242-
; CHECK-FP16-GI-NEXT: mov v0.h[7], v1.h[0]
243-
; CHECK-FP16-GI-NEXT: ret
179+
; CHECK-LABEL: d_to_h:
180+
; CHECK: // %bb.0:
181+
; CHECK-NEXT: fcvtxn v0.2s, v0.2d
182+
; CHECK-NEXT: fcvtxn v2.2s, v2.2d
183+
; CHECK-NEXT: fcvtxn2 v0.4s, v1.2d
184+
; CHECK-NEXT: fcvtxn2 v2.4s, v3.2d
185+
; CHECK-NEXT: fcvtn v0.4h, v0.4s
186+
; CHECK-NEXT: fcvtn2 v0.8h, v2.4s
187+
; CHECK-NEXT: ret
244188
%1 = fptrunc <8 x double> %a to <8 x half>
245189
ret <8 x half> %1
246190
}

0 commit comments

Comments
 (0)