Skip to content

Commit 6290543

Browse files
committed
Remove redundant arguments from x86 codegen funcs
1 parent 3cf8fa3 commit 6290543

File tree

2 files changed

+94
-104
lines changed

2 files changed

+94
-104
lines changed

fearless_simd_gen/src/mk_avx2.rs

Lines changed: 35 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ fn mk_simd_impl() -> TokenStream {
8686
continue;
8787
}
8888

89-
let method = make_method(method, sig, vec_ty, vec_ty.n_bits());
89+
let method = make_method(method, sig, vec_ty);
9090

9191
methods.push(method);
9292
}
@@ -158,7 +158,7 @@ fn mk_type_impl() -> TokenStream {
158158
}
159159
}
160160

161-
fn make_method(method: &str, sig: OpSig, vec_ty: &VecType, ty_bits: usize) -> TokenStream {
161+
fn make_method(method: &str, sig: OpSig, vec_ty: &VecType) -> TokenStream {
162162
let scalar_bits = vec_ty.scalar_bits;
163163
let ty_name = vec_ty.rust_name();
164164
let method_name = format!("{method}_{ty_name}");
@@ -175,14 +175,12 @@ fn make_method(method: &str, sig: OpSig, vec_ty: &VecType, ty_bits: usize) -> To
175175
}
176176

177177
match sig {
178-
OpSig::Splat => mk_sse4_2::handle_splat(method_sig, vec_ty, scalar_bits, ty_bits),
179-
OpSig::Compare => handle_compare(method_sig, method, vec_ty, scalar_bits, ty_bits),
178+
OpSig::Splat => mk_sse4_2::handle_splat(method_sig, vec_ty),
179+
OpSig::Compare => handle_compare(method_sig, method, vec_ty),
180180
OpSig::Unary => mk_sse4_2::handle_unary(method_sig, method, vec_ty),
181-
OpSig::WidenNarrow(t) => {
182-
handle_widen_narrow(method_sig, method, vec_ty, scalar_bits, ty_bits, t)
183-
}
181+
OpSig::WidenNarrow(t) => handle_widen_narrow(method_sig, method, vec_ty, t),
184182
OpSig::Binary => mk_sse4_2::handle_binary(method_sig, method, vec_ty),
185-
OpSig::Shift => mk_sse4_2::handle_shift(method_sig, method, vec_ty, scalar_bits, ty_bits),
183+
OpSig::Shift => mk_sse4_2::handle_shift(method_sig, method, vec_ty),
186184
OpSig::Ternary => match method {
187185
"madd" => {
188186
let intrinsic =
@@ -204,15 +202,13 @@ fn make_method(method: &str, sig: OpSig, vec_ty: &VecType, ty_bits: usize) -> To
204202
}
205203
_ => mk_sse4_2::handle_ternary(method_sig, &method_ident, method, vec_ty),
206204
},
207-
OpSig::Select => mk_sse4_2::handle_select(method_sig, vec_ty, scalar_bits),
205+
OpSig::Select => mk_sse4_2::handle_select(method_sig, vec_ty),
208206
OpSig::Combine => handle_combine(method_sig, vec_ty),
209207
OpSig::Split => handle_split(method_sig, vec_ty),
210-
OpSig::Zip(zip1) => mk_sse4_2::handle_zip(method_sig, vec_ty, scalar_bits, zip1),
211-
OpSig::Unzip(select_even) => {
212-
mk_sse4_2::handle_unzip(method_sig, vec_ty, scalar_bits, select_even)
213-
}
208+
OpSig::Zip(zip1) => mk_sse4_2::handle_zip(method_sig, vec_ty, zip1),
209+
OpSig::Unzip(select_even) => mk_sse4_2::handle_unzip(method_sig, vec_ty, select_even),
214210
OpSig::Cvt(scalar, target_scalar_bits) => {
215-
mk_sse4_2::handle_cvt(method_sig, vec_ty, ty_bits, scalar, target_scalar_bits)
211+
mk_sse4_2::handle_cvt(method_sig, vec_ty, scalar, target_scalar_bits)
216212
}
217213
OpSig::Reinterpret(scalar, target_scalar_bits) => {
218214
mk_sse4_2::handle_reinterpret(method_sig, vec_ty, scalar, target_scalar_bits)
@@ -272,8 +268,6 @@ pub(crate) fn handle_compare(
272268
method_sig: TokenStream,
273269
method: &str,
274270
vec_ty: &VecType,
275-
scalar_bits: usize,
276-
ty_bits: usize,
277271
) -> TokenStream {
278272
if vec_ty.scalar == ScalarType::Float {
279273
// For AVX2 and up, Intel gives us a generic comparison intrinsic that takes a predicate. There are 32,
@@ -288,48 +282,59 @@ pub(crate) fn handle_compare(
288282
"simd_gt" => 0x1E,
289283
_ => unreachable!(),
290284
};
291-
let intrinsic = simple_intrinsic("cmp", vec_ty.scalar, scalar_bits, ty_bits);
292-
let cast = cast_ident(ScalarType::Float, ScalarType::Mask, scalar_bits, ty_bits);
285+
let intrinsic = simple_intrinsic("cmp", vec_ty.scalar, vec_ty.scalar_bits, vec_ty.n_bits());
286+
let cast = cast_ident(
287+
ScalarType::Float,
288+
ScalarType::Mask,
289+
vec_ty.scalar_bits,
290+
vec_ty.n_bits(),
291+
);
293292

294293
quote! {
295294
#method_sig {
296295
unsafe { #cast(#intrinsic::<#order_predicate>(a.into(), b.into())).simd_into(self) }
297296
}
298297
}
299298
} else {
300-
mk_sse4_2::handle_compare(method_sig, method, vec_ty, scalar_bits, ty_bits)
299+
mk_sse4_2::handle_compare(method_sig, method, vec_ty)
301300
}
302301
}
303302

304303
pub(crate) fn handle_widen_narrow(
305304
method_sig: TokenStream,
306305
method: &str,
307306
vec_ty: &VecType,
308-
scalar_bits: usize,
309-
ty_bits: usize,
310307
t: VecType,
311308
) -> TokenStream {
312309
let expr = match method {
313310
"widen" => {
314311
let dst_width = t.n_bits();
315-
match (dst_width, ty_bits) {
312+
match (dst_width, vec_ty.n_bits()) {
316313
(256, 128) => {
317-
let extend =
318-
extend_intrinsic(vec_ty.scalar, scalar_bits, t.scalar_bits, dst_width);
314+
let extend = extend_intrinsic(
315+
vec_ty.scalar,
316+
vec_ty.scalar_bits,
317+
t.scalar_bits,
318+
dst_width,
319+
);
319320
quote! {
320321
unsafe {
321322
#extend(a.into()).simd_into(self)
322323
}
323324
}
324325
}
325326
(512, 256) => {
326-
let extend =
327-
extend_intrinsic(vec_ty.scalar, scalar_bits, t.scalar_bits, ty_bits);
327+
let extend = extend_intrinsic(
328+
vec_ty.scalar,
329+
vec_ty.scalar_bits,
330+
t.scalar_bits,
331+
vec_ty.n_bits(),
332+
);
328333
let combine = format_ident!(
329334
"combine_{}",
330335
VecType {
331336
len: vec_ty.len / 2,
332-
scalar_bits: scalar_bits * 2,
337+
scalar_bits: vec_ty.scalar_bits * 2,
333338
..*vec_ty
334339
}
335340
.rust_name()
@@ -349,7 +354,7 @@ pub(crate) fn handle_widen_narrow(
349354
}
350355
"narrow" => {
351356
let dst_width = t.n_bits();
352-
match (dst_width, ty_bits) {
357+
match (dst_width, vec_ty.n_bits()) {
353358
(128, 256) => {
354359
let mask = match t.scalar_bits {
355360
8 => {
@@ -369,9 +374,9 @@ pub(crate) fn handle_widen_narrow(
369374
}
370375
}
371376
(256, 512) => {
372-
let mask = set1_intrinsic(vec_ty.scalar, scalar_bits, t.n_bits());
377+
let mask = set1_intrinsic(vec_ty.scalar, vec_ty.scalar_bits, t.n_bits());
373378
let pack = pack_intrinsic(
374-
scalar_bits,
379+
vec_ty.scalar_bits,
375380
matches!(vec_ty.scalar, ScalarType::Int),
376381
t.n_bits(),
377382
);

0 commit comments

Comments
 (0)