Skip to content

Commit bd3dcd3

Browse files
authored
x64: Add more fma instruction lowerings (bytecodealliance#5846)
The relaxed-simd proposal for WebAssembly adds a fused-multiply-add operation for `v128` types, so I was poking around at Cranelift's existing support for its `fma` instruction. I was also poking around at the x86_64 ISA's offerings for the FMA operation and ended up with this PR, which improves the lowering of the `fma` instruction on the x64 backend in a number of ways:
* A libcall-based fallback is now provided for `f32x4` and `f64x2` types in preparation for eventual support of the relaxed-simd proposal. These encodings are horribly slow, but it's expected that if FMA semantics must be guaranteed then it's the best that can be done without the `fma` feature. Otherwise it'll be up to producers (e.g. Wasmtime embedders) whether wasm-level FMA operations should be FMA or multiply-then-add.
* In addition to the existing `vfmadd213*` instructions, opcodes were added for `vfmadd132*`. The `132` variant is selected based on which argument can have a sinkable load.
* Any argument in the `fma` CLIF instruction can now have a `sinkable_load` and it'll generate a single FMA instruction.
* All `vfnmadd*` opcodes were added as well. These are pattern-matched where one of the arguments to the CLIF instruction is an `fneg`. I opted to not add a new CLIF instruction here since it seemed like pattern matching was easy enough, but I'm also not intimately familiar with the semantics here, so if that's the preferred approach I can do that too.
1 parent d82ebcc commit bd3dcd3

File tree

9 files changed

+719
-78
lines changed

9 files changed

+719
-78
lines changed

cranelift/codegen/src/isa/x64/inst.isle

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,6 +1199,18 @@
11991199
Vfmadd213sd
12001200
Vfmadd213ps
12011201
Vfmadd213pd
1202+
Vfmadd132ss
1203+
Vfmadd132sd
1204+
Vfmadd132ps
1205+
Vfmadd132pd
1206+
Vfnmadd213ss
1207+
Vfnmadd213sd
1208+
Vfnmadd213ps
1209+
Vfnmadd213pd
1210+
Vfnmadd132ss
1211+
Vfnmadd132sd
1212+
Vfnmadd132ps
1213+
Vfnmadd132pd
12021214
Vcmpps
12031215
Vcmppd
12041216
Vpsrlw
@@ -1623,8 +1635,8 @@
16231635
(decl use_popcnt (bool) Type)
16241636
(extern extractor infallible use_popcnt use_popcnt)
16251637

1626-
(decl use_fma (bool) Type)
1627-
(extern extractor infallible use_fma use_fma)
1638+
(decl pure use_fma () bool)
1639+
(extern constructor use_fma use_fma)
16281640

16291641
(decl use_sse41 (bool) Type)
16301642
(extern extractor infallible use_sse41 use_sse41)
@@ -3598,34 +3610,33 @@
35983610
(_ Unit (emit (MInst.XmmRmRVex3 op src1 src2 src3 dst))))
35993611
dst))
36003612

3601-
;; Helper for creating `vfmadd213ss` instructions.
3602-
; TODO: This should have the (Xmm Xmm XmmMem) signature
3603-
; but we don't support VEX memory encodings yet
3604-
(decl x64_vfmadd213ss (Xmm Xmm Xmm) Xmm)
3605-
(rule (x64_vfmadd213ss x y z)
3606-
(xmm_rmr_vex3 (AvxOpcode.Vfmadd213ss) x y z))
3607-
3608-
;; Helper for creating `vfmadd213sd` instructions.
3609-
; TODO: This should have the (Xmm Xmm XmmMem) signature
3610-
; but we don't support VEX memory encodings yet
3611-
(decl x64_vfmadd213sd (Xmm Xmm Xmm) Xmm)
3612-
(rule (x64_vfmadd213sd x y z)
3613-
(xmm_rmr_vex3 (AvxOpcode.Vfmadd213sd) x y z))
3614-
3615-
;; Helper for creating `vfmadd213ps` instructions.
3616-
; TODO: This should have the (Xmm Xmm XmmMem) signature
3617-
; but we don't support VEX memory encodings yet
3618-
(decl x64_vfmadd213ps (Xmm Xmm Xmm) Xmm)
3619-
(rule (x64_vfmadd213ps x y z)
3620-
(xmm_rmr_vex3 (AvxOpcode.Vfmadd213ps) x y z))
3621-
3622-
;; Helper for creating `vfmadd213pd` instructions.
3623-
; TODO: This should have the (Xmm Xmm XmmMem) signature
3624-
; but we don't support VEX memory encodings yet
3625-
(decl x64_vfmadd213pd (Xmm Xmm Xmm) Xmm)
3626-
(rule (x64_vfmadd213pd x y z)
3627-
(xmm_rmr_vex3 (AvxOpcode.Vfmadd213pd) x y z))
3628-
3613+
;; Helper for creating `vfmadd213*` instructions
3614+
(decl x64_vfmadd213 (Type Xmm Xmm XmmMem) Xmm)
3615+
(rule (x64_vfmadd213 $F32 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd213ss) a b c))
3616+
(rule (x64_vfmadd213 $F64 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd213sd) a b c))
3617+
(rule (x64_vfmadd213 $F32X4 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd213ps) a b c))
3618+
(rule (x64_vfmadd213 $F64X2 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd213pd) a b c))
3619+
3620+
;; Helper for creating `vfmadd132*` instructions
3621+
(decl x64_vfmadd132 (Type Xmm Xmm XmmMem) Xmm)
3622+
(rule (x64_vfmadd132 $F32 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd132ss) a b c))
3623+
(rule (x64_vfmadd132 $F64 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd132sd) a b c))
3624+
(rule (x64_vfmadd132 $F32X4 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd132ps) a b c))
3625+
(rule (x64_vfmadd132 $F64X2 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfmadd132pd) a b c))
3626+
3627+
;; Helper for creating `vfnmadd213*` instructions
3628+
(decl x64_vfnmadd213 (Type Xmm Xmm XmmMem) Xmm)
3629+
(rule (x64_vfnmadd213 $F32 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd213ss) a b c))
3630+
(rule (x64_vfnmadd213 $F64 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd213sd) a b c))
3631+
(rule (x64_vfnmadd213 $F32X4 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd213ps) a b c))
3632+
(rule (x64_vfnmadd213 $F64X2 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd213pd) a b c))
3633+
3634+
;; Helper for creating `vfnmadd132*` instructions
3635+
(decl x64_vfnmadd132 (Type Xmm Xmm XmmMem) Xmm)
3636+
(rule (x64_vfnmadd132 $F32 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd132ss) a b c))
3637+
(rule (x64_vfnmadd132 $F64 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd132sd) a b c))
3638+
(rule (x64_vfnmadd132 $F32X4 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd132ps) a b c))
3639+
(rule (x64_vfnmadd132 $F64X2 a b c) (xmm_rmr_vex3 (AvxOpcode.Vfnmadd132pd) a b c))
36293640

36303641
;; Helper for creating `sqrtss` instructions.
36313642
(decl x64_sqrtss (XmmMem) Xmm)

cranelift/codegen/src/isa/x64/inst/args.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1515,7 +1515,19 @@ impl AvxOpcode {
15151515
AvxOpcode::Vfmadd213ss
15161516
| AvxOpcode::Vfmadd213sd
15171517
| AvxOpcode::Vfmadd213ps
1518-
| AvxOpcode::Vfmadd213pd => smallvec![InstructionSet::FMA],
1518+
| AvxOpcode::Vfmadd213pd
1519+
| AvxOpcode::Vfmadd132ss
1520+
| AvxOpcode::Vfmadd132sd
1521+
| AvxOpcode::Vfmadd132ps
1522+
| AvxOpcode::Vfmadd132pd
1523+
| AvxOpcode::Vfnmadd213ss
1524+
| AvxOpcode::Vfnmadd213sd
1525+
| AvxOpcode::Vfnmadd213ps
1526+
| AvxOpcode::Vfnmadd213pd
1527+
| AvxOpcode::Vfnmadd132ss
1528+
| AvxOpcode::Vfnmadd132sd
1529+
| AvxOpcode::Vfnmadd132ps
1530+
| AvxOpcode::Vfnmadd132pd => smallvec![InstructionSet::FMA],
15191531
AvxOpcode::Vminps
15201532
| AvxOpcode::Vminpd
15211533
| AvxOpcode::Vmaxps

cranelift/codegen/src/isa/x64/inst/emit.rs

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2281,32 +2281,46 @@ pub(crate) fn emit(
22812281
let dst = allocs.next(dst.to_reg().to_reg());
22822282
debug_assert_eq!(src1, dst);
22832283
let src2 = allocs.next(src2.to_reg());
2284-
let src3 = src3.clone().to_reg_mem().with_allocs(allocs);
2284+
let src3 = match src3.clone().to_reg_mem().with_allocs(allocs) {
2285+
RegMem::Reg { reg } => {
2286+
RegisterOrAmode::Register(reg.to_real_reg().unwrap().hw_enc().into())
2287+
}
2288+
RegMem::Mem { addr } => RegisterOrAmode::Amode(addr.finalize(state, sink)),
2289+
};
22852290

22862291
let (w, map, opcode) = match op {
2292+
AvxOpcode::Vfmadd132ss => (false, OpcodeMap::_0F38, 0x99),
22872293
AvxOpcode::Vfmadd213ss => (false, OpcodeMap::_0F38, 0xA9),
2294+
AvxOpcode::Vfnmadd132ss => (false, OpcodeMap::_0F38, 0x9D),
2295+
AvxOpcode::Vfnmadd213ss => (false, OpcodeMap::_0F38, 0xAD),
2296+
AvxOpcode::Vfmadd132sd => (true, OpcodeMap::_0F38, 0x99),
22882297
AvxOpcode::Vfmadd213sd => (true, OpcodeMap::_0F38, 0xA9),
2298+
AvxOpcode::Vfnmadd132sd => (true, OpcodeMap::_0F38, 0x9D),
2299+
AvxOpcode::Vfnmadd213sd => (true, OpcodeMap::_0F38, 0xAD),
2300+
AvxOpcode::Vfmadd132ps => (false, OpcodeMap::_0F38, 0x98),
22892301
AvxOpcode::Vfmadd213ps => (false, OpcodeMap::_0F38, 0xA8),
2302+
AvxOpcode::Vfnmadd132ps => (false, OpcodeMap::_0F38, 0x9C),
2303+
AvxOpcode::Vfnmadd213ps => (false, OpcodeMap::_0F38, 0xAC),
2304+
AvxOpcode::Vfmadd132pd => (true, OpcodeMap::_0F38, 0x98),
22902305
AvxOpcode::Vfmadd213pd => (true, OpcodeMap::_0F38, 0xA8),
2306+
AvxOpcode::Vfnmadd132pd => (true, OpcodeMap::_0F38, 0x9C),
2307+
AvxOpcode::Vfnmadd213pd => (true, OpcodeMap::_0F38, 0xAC),
22912308
AvxOpcode::Vblendvps => (false, OpcodeMap::_0F3A, 0x4A),
22922309
AvxOpcode::Vblendvpd => (false, OpcodeMap::_0F3A, 0x4B),
22932310
AvxOpcode::Vpblendvb => (false, OpcodeMap::_0F3A, 0x4C),
22942311
_ => unreachable!(),
22952312
};
22962313

2297-
match src3 {
2298-
RegMem::Reg { reg: src } => VexInstruction::new()
2299-
.length(VexVectorLength::V128)
2300-
.prefix(LegacyPrefixes::_66)
2301-
.map(map)
2302-
.w(w)
2303-
.opcode(opcode)
2304-
.reg(dst.to_real_reg().unwrap().hw_enc())
2305-
.rm(src.to_real_reg().unwrap().hw_enc())
2306-
.vvvv(src2.to_real_reg().unwrap().hw_enc())
2307-
.encode(sink),
2308-
_ => todo!(),
2309-
};
2314+
VexInstruction::new()
2315+
.length(VexVectorLength::V128)
2316+
.prefix(LegacyPrefixes::_66)
2317+
.map(map)
2318+
.w(w)
2319+
.opcode(opcode)
2320+
.reg(dst.to_real_reg().unwrap().hw_enc())
2321+
.rm(src3)
2322+
.vvvv(src2.to_real_reg().unwrap().hw_enc())
2323+
.encode(sink);
23102324
}
23112325

23122326
Inst::XmmRmRBlendVex {

cranelift/codegen/src/isa/x64/inst/mod.rs

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,23 +1944,12 @@ fn x64_get_operands<F: Fn(VReg) -> VReg>(inst: &Inst, collector: &mut OperandCol
19441944
src2.get_operands(collector);
19451945
}
19461946
Inst::XmmRmRVex3 {
1947-
op,
19481947
src1,
19491948
src2,
19501949
src3,
19511950
dst,
19521951
..
19531952
} => {
1954-
// Vfmadd uses and defs the dst reg, that is not the case with all
1955-
// AVX's ops, if you're adding a new op, make sure to correctly define
1956-
// register uses.
1957-
assert!(
1958-
*op == AvxOpcode::Vfmadd213ss
1959-
|| *op == AvxOpcode::Vfmadd213sd
1960-
|| *op == AvxOpcode::Vfmadd213ps
1961-
|| *op == AvxOpcode::Vfmadd213pd
1962-
);
1963-
19641953
collector.reg_use(src1.to_reg());
19651954
collector.reg_reuse_def(dst.to_writable_reg(), 0);
19661955
collector.reg_use(src2.to_reg());

cranelift/codegen/src/isa/x64/lower.isle

Lines changed: 81 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2167,13 +2167,13 @@
21672167
;; The above rules automatically sink loads for rhs operands, so additionally
21682168
;; add rules for sinking loads with lhs operands.
21692169
(rule 1 (lower (has_type $F32 (fadd (sinkable_load x) y)))
2170-
(x64_addss y (sink_load x)))
2170+
(x64_addss y x))
21712171
(rule 1 (lower (has_type $F64 (fadd (sinkable_load x) y)))
2172-
(x64_addsd y (sink_load x)))
2172+
(x64_addsd y x))
21732173
(rule 1 (lower (has_type $F32X4 (fadd (sinkable_load x) y)))
2174-
(x64_addps y (sink_load x)))
2174+
(x64_addps y x))
21752175
(rule 1 (lower (has_type $F64X2 (fadd (sinkable_load x) y)))
2176-
(x64_addpd y (sink_load x)))
2176+
(x64_addpd y x))
21772177

21782178
;; Rules for `fsub` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
21792179

@@ -2200,13 +2200,13 @@
22002200
;; The above rules automatically sink loads for rhs operands, so additionally
22012201
;; add rules for sinking loads with lhs operands.
22022202
(rule 1 (lower (has_type $F32 (fmul (sinkable_load x) y)))
2203-
(x64_mulss y (sink_load x)))
2203+
(x64_mulss y x))
22042204
(rule 1 (lower (has_type $F64 (fmul (sinkable_load x) y)))
2205-
(x64_mulsd y (sink_load x)))
2205+
(x64_mulsd y x))
22062206
(rule 1 (lower (has_type $F32X4 (fmul (sinkable_load x) y)))
2207-
(x64_mulps y (sink_load x)))
2207+
(x64_mulps y x))
22082208
(rule 1 (lower (has_type $F64X2 (fmul (sinkable_load x) y)))
2209-
(x64_mulpd y (sink_load x)))
2209+
(x64_mulpd y x))
22102210

22112211
;; Rules for `fdiv` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
22122212

@@ -2438,18 +2438,83 @@
24382438

24392439
;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
24402440

2441+
;; Base case for fma is to call out to one of two libcalls. Vectors need to be
2442+
;; decomposed, have each element handled individually, and then be recomposed.
2443+
24412444
(rule (lower (has_type $F32 (fma x y z)))
24422445
(libcall_3 (LibCall.FmaF32) x y z))
24432446
(rule (lower (has_type $F64 (fma x y z)))
24442447
(libcall_3 (LibCall.FmaF64) x y z))
2445-
(rule 1 (lower (has_type (and (use_fma $true) $F32) (fma x y z)))
2446-
(x64_vfmadd213ss x y z))
2447-
(rule 1 (lower (has_type (and (use_fma $true) $F64) (fma x y z)))
2448-
(x64_vfmadd213sd x y z))
2449-
(rule (lower (has_type (and (use_fma $true) $F32X4) (fma x y z)))
2450-
(x64_vfmadd213ps x y z))
2451-
(rule (lower (has_type (and (use_fma $true) $F64X2) (fma x y z)))
2452-
(x64_vfmadd213pd x y z))
2448+
2449+
(rule (lower (has_type $F32X4 (fma x y z)))
2450+
(let (
2451+
(x Xmm (put_in_xmm x))
2452+
(y Xmm (put_in_xmm y))
2453+
(z Xmm (put_in_xmm z))
2454+
(x0 Xmm (libcall_3 (LibCall.FmaF32) x y z))
2455+
(x1 Xmm (libcall_3 (LibCall.FmaF32)
2456+
(x64_pshufd x 1)
2457+
(x64_pshufd y 1)
2458+
(x64_pshufd z 1)))
2459+
(x2 Xmm (libcall_3 (LibCall.FmaF32)
2460+
(x64_pshufd x 2)
2461+
(x64_pshufd y 2)
2462+
(x64_pshufd z 2)))
2463+
(x3 Xmm (libcall_3 (LibCall.FmaF32)
2464+
(x64_pshufd x 3)
2465+
(x64_pshufd y 3)
2466+
(x64_pshufd z 3)))
2467+
2468+
(tmp Xmm (vec_insert_lane $F32X4 x0 x1 1))
2469+
(tmp Xmm (vec_insert_lane $F32X4 tmp x2 2))
2470+
(tmp Xmm (vec_insert_lane $F32X4 tmp x3 3))
2471+
)
2472+
tmp))
2473+
(rule (lower (has_type $F64X2 (fma x y z)))
2474+
(let (
2475+
(x Xmm (put_in_xmm x))
2476+
(y Xmm (put_in_xmm y))
2477+
(z Xmm (put_in_xmm z))
2478+
(x0 Xmm (libcall_3 (LibCall.FmaF64) x y z))
2479+
(x1 Xmm (libcall_3 (LibCall.FmaF64)
2480+
(x64_pshufd x 0xee)
2481+
(x64_pshufd y 0xee)
2482+
(x64_pshufd z 0xee)))
2483+
)
2484+
(vec_insert_lane $F64X2 x0 x1 1)))
2485+
2486+
2487+
;; Special case for when the `fma` feature is active and a native instruction
2488+
;; can be used.
2489+
(rule 1 (lower (has_type ty (fma x y z)))
2490+
(if-let $true (use_fma))
2491+
(fmadd ty x y z))
2492+
2493+
(decl fmadd (Type Value Value Value) Xmm)
2494+
(decl fnmadd (Type Value Value Value) Xmm)
2495+
2496+
;; Base case. Note that this will automatically sink a load with `z`, the value
2497+
;; to add.
2498+
(rule (fmadd ty x y z) (x64_vfmadd213 ty x y z))
2499+
2500+
;; Allow sinking loads with one of the two values being multiplied in addition
2501+
;; to the value being added. Note that both x and y can be sunk here due to
2502+
;; multiplication being commutative.
2503+
(rule 1 (fmadd ty (sinkable_load x) y z) (x64_vfmadd132 ty y z x))
2504+
(rule 2 (fmadd ty x (sinkable_load y) z) (x64_vfmadd132 ty x z y))
2505+
2506+
;; If one of the values being multiplied is negated then use a `vfnmadd*`
2507+
;; instruction instead
2508+
(rule 3 (fmadd ty (fneg x) y z) (fnmadd ty x y z))
2509+
(rule 4 (fmadd ty x (fneg y) z) (fnmadd ty x y z))
2510+
2511+
(rule (fnmadd ty x y z) (x64_vfnmadd213 ty x y z))
2512+
(rule 1 (fnmadd ty (sinkable_load x) y z) (x64_vfnmadd132 ty y z x))
2513+
(rule 2 (fnmadd ty x (sinkable_load y) z) (x64_vfnmadd132 ty x z y))
2514+
2515+
;; Like `fmadd`, if one argument is negated, switch which one is being codegen'd
2516+
(rule 3 (fnmadd ty (fneg x) y z) (fmadd ty x y z))
2517+
(rule 4 (fnmadd ty x (fneg y) z) (fmadd ty x y z))
24532518

24542519
;; Rules for `load*` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
24552520

cranelift/codegen/src/isa/x64/lower/isle.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ impl Context for IsleContext<'_, '_, MInst, X64Backend> {
213213
}
214214

215215
#[inline]
216-
fn use_fma(&mut self, _: Type) -> bool {
216+
fn use_fma(&mut self) -> bool {
217217
self.backend.x64_flags.use_fma()
218218
}
219219

0 commit comments

Comments
 (0)