Skip to content

Commit 86bd6c8

Browse files
authored
Add initial f16 and f128 support to the x64 backend (bytecodealliance#9045)
1 parent 12fc764 commit 86bd6c8

20 files changed

+969
-44
lines changed

cranelift/codegen/src/isa/x64/abi.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,12 @@ impl ABIMachineSpec for X64ABIMachineSpec {
123123
// extension annotations. Additionally, handling extension attributes this way allows clif
124124
// functions that use them with the Winch calling convention to interact successfully with
125125
// testing infrastructure.
126+
// The results are also not packed if any of the types are `f16`. This is to simplify the
127+
// implementation of `Inst::load`/`Inst::store` (which would otherwise require multiple
128+
// instructions), and doesn't affect Winch itself as Winch doesn't support `f16` at all.
126129
let uses_extension = params
127130
.iter()
128-
.any(|p| p.extension != ir::ArgumentExtension::None);
131+
.any(|p| p.extension != ir::ArgumentExtension::None || p.value_type == types::F16);
129132

130133
for (ix, param) in params.iter().enumerate() {
131134
let last_param = ix == params.len() - 1;
@@ -169,13 +172,23 @@ impl ABIMachineSpec for X64ABIMachineSpec {
169172
// https://godbolt.org/z/PhG3ob
170173

171174
if param.value_type.bits() > 64
172-
&& !param.value_type.is_vector()
175+
&& !(param.value_type.is_vector() || param.value_type.is_float())
173176
&& !flags.enable_llvm_abi_extensions()
174177
{
175178
panic!(
176179
"i128 args/return values not supported unless LLVM ABI extensions are enabled"
177180
);
178181
}
182+
// As MSVC doesn't support f16/f128 there is no standard way to pass/return them with
183+
// the Windows ABI. LLVM passes/returns them in XMM registers.
184+
if matches!(param.value_type, types::F16 | types::F128)
185+
&& is_fastcall
186+
&& !flags.enable_llvm_abi_extensions()
187+
{
188+
panic!(
189+
"f16/f128 args/return values not supported for windows_fastcall unless LLVM ABI extensions are enabled"
190+
);
191+
}
179192

180193
// Windows fastcall dictates that `__m128i` parameters to a function
181194
// are passed indirectly as pointers, so handle that as a special
@@ -410,12 +423,20 @@ impl ABIMachineSpec for X64ABIMachineSpec {
410423
// bits as well -- see `Inst::store()`).
411424
let ty = match ty {
412425
types::I8 | types::I16 | types::I32 => types::I64,
426+
// Stack slots are always at least 8 bytes, so it's fine to load 4 bytes instead of only
427+
// two.
428+
types::F16 => types::F32,
413429
_ => ty,
414430
};
415431
Inst::load(ty, mem, into_reg, ExtKind::None)
416432
}
417433

418434
fn gen_store_stack(mem: StackAMode, from_reg: Reg, ty: Type) -> Self::I {
435+
let ty = match ty {
436+
// See `gen_load_stack`.
437+
types::F16 => types::F32,
438+
_ => ty,
439+
};
419440
Inst::store(ty, from_reg, mem)
420441
}
421442

@@ -502,6 +523,11 @@ impl ABIMachineSpec for X64ABIMachineSpec {
502523
}
503524

504525
fn gen_store_base_offset(base: Reg, offset: i32, from_reg: Reg, ty: Type) -> Self::I {
526+
let ty = match ty {
527+
// See `gen_load_stack`.
528+
types::F16 => types::F32,
529+
_ => ty,
530+
};
505531
let mem = Amode::imm_reg(offset, base);
506532
Inst::store(ty, from_reg, mem)
507533
}

cranelift/codegen/src/isa/x64/inst.isle

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,7 +1644,7 @@
16441644
(rule (put_in_gpr val)
16451645
(if-let (value_type ty) val)
16461646
(if-let (type_register_class (RegisterClass.Xmm)) ty)
1647-
(bitcast_xmm_to_gpr ty (xmm_new (put_in_reg val))))
1647+
(bitcast_xmm_to_gpr (ty_bits ty) (xmm_new (put_in_reg val))))
16481648

16491649
;; Put a value into a `GprMem`.
16501650
;;
@@ -2252,8 +2252,10 @@
22522252

22532253
;; Performs an xor operation of the two operands specified.
22542254
(decl x64_xor_vector (Type Xmm XmmMem) Xmm)
2255+
(rule 1 (x64_xor_vector $F16 x y) (x64_xorps x y))
22552256
(rule 1 (x64_xor_vector $F32 x y) (x64_xorps x y))
22562257
(rule 1 (x64_xor_vector $F64 x y) (x64_xorpd x y))
2258+
(rule 1 (x64_xor_vector $F128 x y) (x64_xorps x y))
22572259
(rule 1 (x64_xor_vector $F32X4 x y) (x64_xorps x y))
22582260
(rule 1 (x64_xor_vector $F64X2 x y) (x64_xorpd x y))
22592261
(rule 0 (x64_xor_vector (multi_lane _ _) x y) (x64_pxor x y))
@@ -2304,6 +2306,9 @@
23042306
(rule 2 (x64_load $F64 addr _ext_kind)
23052307
(x64_movsd_load addr))
23062308

2309+
(rule 2 (x64_load $F128 addr _ext_kind)
2310+
(x64_movdqu_load addr))
2311+
23072312
(rule 2 (x64_load $F32X4 addr _ext_kind)
23082313
(x64_movups_load addr))
23092314

@@ -2719,6 +2724,10 @@
27192724
(_ Unit (emit (MInst.Imm size simm64 dst))))
27202725
dst))
27212726

2727+
;; `f16` immediates.
2728+
(rule 2 (imm $F16 (u64_nonzero bits))
2729+
(bitcast_gpr_to_xmm 16 (imm $I16 bits)))
2730+
27222731
;; `f32` immediates.
27232732
(rule 2 (imm $F32 (u64_nonzero bits))
27242733
(x64_movd_to_xmm (imm $I32 bits)))
@@ -2746,6 +2755,9 @@
27462755
(rule 0 (imm ty @ (multi_lane _bits _lanes) 0)
27472756
(xmm_to_reg (xmm_zero ty)))
27482757

2758+
;; Special case for `f16` zero immediates
2759+
(rule 2 (imm ty @ $F16 (u64_zero)) (xmm_zero ty))
2760+
27492761
;; Special case for `f32` zero immediates
27502762
(rule 2 (imm ty @ $F32 (u64_zero)) (xmm_zero ty))
27512763

@@ -5022,18 +5034,30 @@
50225034

50235035
;;;; Casting ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
50245036

5025-
(decl bitcast_xmm_to_gpr (Type Xmm) Gpr)
5026-
(rule (bitcast_xmm_to_gpr $F32 src)
5037+
(decl bitcast_xmm_to_gpr (u8 Xmm) Gpr)
5038+
(rule (bitcast_xmm_to_gpr 16 src)
5039+
(x64_pextrw src 0))
5040+
(rule (bitcast_xmm_to_gpr 32 src)
50275041
(x64_movd_to_gpr src))
5028-
(rule (bitcast_xmm_to_gpr $F64 src)
5042+
(rule (bitcast_xmm_to_gpr 64 src)
50295043
(x64_movq_to_gpr src))
50305044

5031-
(decl bitcast_gpr_to_xmm (Type Gpr) Xmm)
5032-
(rule (bitcast_gpr_to_xmm $I32 src)
5045+
(decl bitcast_xmm_to_gprs (Xmm) ValueRegs)
5046+
(rule (bitcast_xmm_to_gprs src)
5047+
(value_regs (x64_movq_to_gpr src) (x64_movq_to_gpr (x64_pshufd src 0b11101110))))
5048+
5049+
(decl bitcast_gpr_to_xmm (u8 Gpr) Xmm)
5050+
(rule (bitcast_gpr_to_xmm 16 src)
5051+
(x64_pinsrw (xmm_uninit_value) src 0))
5052+
(rule (bitcast_gpr_to_xmm 32 src)
50335053
(x64_movd_to_xmm src))
5034-
(rule (bitcast_gpr_to_xmm $I64 src)
5054+
(rule (bitcast_gpr_to_xmm 64 src)
50355055
(x64_movq_to_xmm src))
50365056

5057+
(decl bitcast_gprs_to_xmm (ValueRegs) Xmm)
5058+
(rule (bitcast_gprs_to_xmm src)
5059+
(x64_punpcklqdq (x64_movq_to_xmm (value_regs_get_gpr src 0)) (x64_movq_to_xmm (value_regs_get_gpr src 1))))
5060+
50375061
;;;; Stack Addresses ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
50385062

50395063
(decl stack_addr_impl (StackSlot Offset32) Gpr)

cranelift/codegen/src/isa/x64/inst/emit.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1428,10 +1428,11 @@ pub(crate) fn emit(
14281428
let op = match *ty {
14291429
types::F64 => SseOpcode::Movsd,
14301430
types::F32 => SseOpcode::Movsd,
1431+
types::F16 => SseOpcode::Movsd,
14311432
types::F32X4 => SseOpcode::Movaps,
14321433
types::F64X2 => SseOpcode::Movapd,
14331434
ty => {
1434-
debug_assert!(ty.is_vector() && ty.bytes() == 16);
1435+
debug_assert!((ty.is_float() || ty.is_vector()) && ty.bytes() == 16);
14351436
SseOpcode::Movdqa
14361437
}
14371438
};

cranelift/codegen/src/isa/x64/inst/mod.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -630,11 +630,12 @@ impl Inst {
630630
}
631631
RegClass::Float => {
632632
let opcode = match ty {
633+
types::F16 => panic!("loading a f16 requires multiple instructions"),
633634
types::F32 => SseOpcode::Movss,
634635
types::F64 => SseOpcode::Movsd,
635636
types::F32X4 => SseOpcode::Movups,
636637
types::F64X2 => SseOpcode::Movupd,
637-
_ if ty.is_vector() && ty.bits() == 128 => SseOpcode::Movdqu,
638+
_ if (ty.is_float() || ty.is_vector()) && ty.bits() == 128 => SseOpcode::Movdqu,
638639
_ => unimplemented!("unable to load type: {}", ty),
639640
};
640641
Inst::xmm_unary_rm_r(opcode, RegMem::mem(from_addr), to_reg)
@@ -650,11 +651,12 @@ impl Inst {
650651
RegClass::Int => Inst::mov_r_m(OperandSize::from_ty(ty), from_reg, to_addr),
651652
RegClass::Float => {
652653
let opcode = match ty {
654+
types::F16 => panic!("storing a f16 requires multiple instructions"),
653655
types::F32 => SseOpcode::Movss,
654656
types::F64 => SseOpcode::Movsd,
655657
types::F32X4 => SseOpcode::Movups,
656658
types::F64X2 => SseOpcode::Movupd,
657-
_ if ty.is_vector() && ty.bits() == 128 => SseOpcode::Movdqu,
659+
_ if (ty.is_float() || ty.is_vector()) && ty.bits() == 128 => SseOpcode::Movdqu,
658660
_ => unimplemented!("unable to store type: {}", ty),
659661
};
660662
Inst::xmm_mov_r_m(opcode, from_reg, to_addr)
@@ -1621,6 +1623,7 @@ impl PrettyPrint for Inst {
16211623
let suffix = match *ty {
16221624
types::F64 => "sd",
16231625
types::F32 => "ss",
1626+
types::F16 => "ss",
16241627
types::F32X4 => "aps",
16251628
types::F64X2 => "apd",
16261629
_ => "dqa",
@@ -2605,9 +2608,9 @@ impl MachInst for Inst {
26052608
// those, which may write more lanes that we need, but are specified to have
26062609
// zero-latency.
26072610
let opcode = match ty {
2608-
types::F32 | types::F64 | types::F32X4 => SseOpcode::Movaps,
2611+
types::F16 | types::F32 | types::F64 | types::F32X4 => SseOpcode::Movaps,
26092612
types::F64X2 => SseOpcode::Movapd,
2610-
_ if ty.is_vector() && ty.bits() == 128 => SseOpcode::Movdqa,
2613+
_ if (ty.is_float() || ty.is_vector()) && ty.bits() == 128 => SseOpcode::Movdqa,
26112614
_ => unimplemented!("unable to move type: {}", ty),
26122615
};
26132616
Inst::xmm_unary_rm_r(opcode, RegMem::reg(src_reg), dst_reg)
@@ -2628,8 +2631,10 @@ impl MachInst for Inst {
26282631
types::I64 => Ok((&[RegClass::Int], &[types::I64])),
26292632
types::R32 => panic!("32-bit reftype pointer should never be seen on x86-64"),
26302633
types::R64 => Ok((&[RegClass::Int], &[types::R64])),
2634+
types::F16 => Ok((&[RegClass::Float], &[types::F16])),
26312635
types::F32 => Ok((&[RegClass::Float], &[types::F32])),
26322636
types::F64 => Ok((&[RegClass::Float], &[types::F64])),
2637+
types::F128 => Ok((&[RegClass::Float], &[types::F128])),
26332638
types::I128 => Ok((&[RegClass::Int, RegClass::Int], &[types::I64, types::I64])),
26342639
_ if ty.is_vector() => {
26352640
assert!(ty.bits() <= 128);

0 commit comments

Comments
 (0)