diff --git a/Cargo.lock b/Cargo.lock index f6b6e79..9fb1297 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,12 +52,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "bytemuck" -version = "1.23.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c76a5792e44e4abe34d3abf15636779261d45a7450612059293d1d2cfc63422" - [[package]] name = "clap" version = "4.5.40" @@ -108,7 +102,6 @@ checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" name = "fearless_simd" version = "0.3.0" dependencies = [ - "bytemuck", "libm", ] diff --git a/fearless_simd/Cargo.toml b/fearless_simd/Cargo.toml index 33f82c6..9cc5e96 100644 --- a/fearless_simd/Cargo.toml +++ b/fearless_simd/Cargo.toml @@ -38,5 +38,4 @@ force_support_fallback = [] workspace = true [dependencies] -bytemuck = "1.23.0" libm = { version = "0.2.15", optional = true } diff --git a/fearless_simd/src/generated/avx2.rs b/fearless_simd/src/generated/avx2.rs index 853a47e..d3e47b2 100644 --- a/fearless_simd/src/generated/avx2.rs +++ b/fearless_simd/src/generated/avx2.rs @@ -179,31 +179,19 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_f64_f32x4(self, a: f32x4) -> f64x2 { - f64x2 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm_castps_pd(a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_i32_f32x4(self, a: f32x4) -> i32x4 { - i32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm_castps_si128(a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u8_f32x4(self, a: f32x4) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm_castps_si128(a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u32_f32x4(self, a: f32x4) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm_castps_si128(a.into()).simd_into(self) } } #[inline(always)] fn cvt_u32_f32x4(self, a: f32x4) -> u32x4 { @@ -352,17 +340,11 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_i8x16(self, a: i8x16) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_i8x16(self, a: i8x16) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn splat_u8x16(self, val: u8) -> u8x16 { @@ -511,10 +493,7 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u32_u8x16(self, a: u8x16) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn splat_mask8x16(self, val: i8) -> mask8x16 { @@ -665,17 +644,11 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_i16x8(self, a: i16x8) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_i16x8(self, a: i16x8) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn splat_u16x8(self, val: u16) -> u16x8 { @@ -795,17 +768,11 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_u16x8(self, a: u16x8) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_u16x8(self, a: u16x8) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn splat_mask16x8(self, val: i16) -> mask16x8 { @@ -954,17 +921,11 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_i32x4(self, a: i32x4) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_i32x4(self, a: i32x4) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn cvt_f32_i32x4(self, a: i32x4) -> f32x4 { @@ -1086,10 +1047,7 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_u32x4(self, a: u32x4) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn cvt_f32_u32x4(self, a: u32x4) -> f32x4 { @@ -1253,10 +1211,7 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_f32_f64x2(self, a: f64x2) -> f32x4 { - f32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm_castpd_ps(a.into()).simd_into(self) } } #[inline(always)] fn splat_mask64x2(self, val: i64) -> mask64x2 { @@ -1450,31 +1405,19 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_f64_f32x8(self, a: f32x8) -> f64x4 { - f64x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm256_castps_pd(a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_i32_f32x8(self, a: f32x8) -> i32x8 { - i32x8 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm256_castps_si256(a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u8_f32x8(self, a: f32x8) -> u8x32 { - u8x32 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm256_castps_si256(a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u32_f32x8(self, a: f32x8) -> u32x8 { - u32x8 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm256_castps_si256(a.into()).simd_into(self) } } #[inline(always)] fn cvt_u32_f32x8(self, a: f32x8) -> u32x8 { @@ -1663,17 +1606,11 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_i8x32(self, a: i8x32) -> u8x32 { - u8x32 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m256i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_i8x32(self, a: i8x32) -> u32x8 { - u32x8 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m256i::from(a).simd_into(self) } #[inline(always)] fn splat_u8x32(self, val: u8) -> u8x32 { @@ -1863,10 +1800,7 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u32_u8x32(self, a: u8x32) -> u32x8 { - u32x8 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m256i::from(a).simd_into(self) } #[inline(always)] fn splat_mask8x32(self, val: i8) -> mask8x32 { @@ -2067,17 +2001,11 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_i16x16(self, a: i16x16) -> u8x32 { - u8x32 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m256i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_i16x16(self, a: i16x16) -> u32x8 { - u32x8 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m256i::from(a).simd_into(self) } #[inline(always)] fn splat_u16x16(self, val: u16) -> u16x16 { @@ -2247,17 +2175,11 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_u16x16(self, a: u16x16) -> u8x32 { - u8x32 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m256i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_u16x16(self, a: u16x16) -> u32x8 { - u32x8 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m256i::from(a).simd_into(self) } #[inline(always)] fn splat_mask16x16(self, val: i16) -> mask16x16 { @@ -2446,17 +2368,11 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_i32x8(self, a: i32x8) -> u8x32 { - u8x32 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m256i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_i32x8(self, a: i32x8) -> u32x8 { - u32x8 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m256i::from(a).simd_into(self) } #[inline(always)] fn cvt_f32_i32x8(self, a: i32x8) -> f32x8 { @@ -2606,10 +2522,7 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_u8_u32x8(self, a: u32x8) -> u8x32 { - u8x32 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m256i::from(a).simd_into(self) } #[inline(always)] fn cvt_f32_u32x8(self, a: u32x8) -> f32x8 { @@ -2819,10 +2732,7 @@ impl Simd for Avx2 { } #[inline(always)] fn reinterpret_f32_f64x4(self, a: f64x4) -> f32x8 { - f32x8 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm256_castpd_ps(a.into()).simd_into(self) } } #[inline(always)] fn splat_mask64x4(self, val: i64) -> mask64x4 { diff --git a/fearless_simd/src/generated/fallback.rs b/fearless_simd/src/generated/fallback.rs index 8e7d1db..5bf310d 100644 --- a/fearless_simd/src/generated/fallback.rs +++ b/fearless_simd/src/generated/fallback.rs @@ -3,7 +3,7 @@ // This file is autogenerated by fearless_simd_gen -use crate::{Level, Simd, SimdInto, seal::Seal}; +use crate::{Bytes, Level, Simd, SimdInto, seal::Seal}; use crate::{ f32x4, f32x8, f32x16, f64x2, f64x4, f64x8, i8x16, i8x32, i8x64, i16x8, i16x16, i16x32, i32x4, i32x8, i32x16, mask8x16, mask8x32, mask8x64, mask16x8, mask16x16, mask16x32, mask32x4, @@ -339,31 +339,19 @@ impl Simd for Fallback { } #[inline(always)] fn reinterpret_f64_f32x4(self, a: f32x4) -> f64x2 { - f64x2 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn reinterpret_i32_f32x4(self, a: f32x4) -> i32x4 { - i32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn reinterpret_u8_f32x4(self, a: f32x4) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn reinterpret_u32_f32x4(self, a: f32x4) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn cvt_u32_f32x4(self, a: f32x4) -> u32x4 { @@ -875,17 +863,11 @@ impl Simd for Fallback { } #[inline(always)] fn reinterpret_u8_i8x16(self, a: i8x16) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn reinterpret_u32_i8x16(self, a: i8x16) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn splat_u8x16(self, val: u8) -> u8x16 { @@ -1377,10 +1359,7 @@ impl Simd for Fallback { } #[inline(always)] fn reinterpret_u32_u8x16(self, a: u8x16) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn splat_mask8x16(self, val: i8) -> mask8x16 { @@ -1861,17 +1840,11 @@ impl Simd for Fallback { } #[inline(always)] fn reinterpret_u8_i16x8(self, a: i16x8) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn reinterpret_u32_i16x8(self, a: i16x8) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn splat_u16x8(self, val: u16) -> u16x8 { @@ -2166,17 +2139,11 @@ impl Simd for Fallback { } #[inline(always)] fn reinterpret_u8_u16x8(self, a: u16x8) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn reinterpret_u32_u16x8(self, a: u16x8) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn splat_mask16x8(self, val: i16) -> mask16x8 { @@ -2497,17 +2464,11 @@ impl Simd for Fallback { } #[inline(always)] fn reinterpret_u8_i32x4(self, a: i32x4) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn reinterpret_u32_i32x4(self, a: i32x4) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn cvt_f32_i32x4(self, a: i32x4) -> f32x4 { @@ -2728,10 +2689,7 @@ impl Simd for Fallback { } #[inline(always)] fn reinterpret_u8_u32x4(self, a: u32x4) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn cvt_f32_u32x4(self, a: u32x4) -> f32x4 { @@ -3000,10 +2958,7 @@ impl Simd for Fallback { } #[inline(always)] fn reinterpret_f32_f64x2(self, a: f64x2) -> f32x4 { - f32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } #[inline(always)] fn splat_mask64x2(self, val: i64) -> mask64x2 { diff --git a/fearless_simd/src/generated/sse4_2.rs b/fearless_simd/src/generated/sse4_2.rs index b6d53c3..5e4038e 100644 --- a/fearless_simd/src/generated/sse4_2.rs +++ b/fearless_simd/src/generated/sse4_2.rs @@ -187,31 +187,19 @@ impl Simd for Sse4_2 { } #[inline(always)] fn reinterpret_f64_f32x4(self, a: f32x4) -> f64x2 { - f64x2 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm_castps_pd(a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_i32_f32x4(self, a: f32x4) -> i32x4 { - i32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm_castps_si128(a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u8_f32x4(self, a: f32x4) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm_castps_si128(a.into()).simd_into(self) } } #[inline(always)] fn reinterpret_u32_f32x4(self, a: f32x4) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm_castps_si128(a.into()).simd_into(self) } } #[inline(always)] fn cvt_u32_f32x4(self, a: f32x4) -> u32x4 { @@ -363,17 +351,11 @@ impl Simd for Sse4_2 { } #[inline(always)] fn reinterpret_u8_i8x16(self, a: i8x16) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_i8x16(self, a: i8x16) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn splat_u8x16(self, val: u8) -> u8x16 { @@ -530,10 +512,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn reinterpret_u32_u8x16(self, a: u8x16) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn splat_mask8x16(self, val: i8) -> mask8x16 { @@ -690,17 +669,11 @@ impl Simd for Sse4_2 { } #[inline(always)] fn reinterpret_u8_i16x8(self, a: i16x8) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_i16x8(self, a: i16x8) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn splat_u16x8(self, val: u16) -> u16x8 { @@ -823,17 +796,11 @@ impl Simd for Sse4_2 { } #[inline(always)] fn reinterpret_u8_u16x8(self, a: u16x8) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_u16x8(self, a: u16x8) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn splat_mask16x8(self, val: i16) -> mask16x8 { @@ -988,17 +955,11 @@ impl Simd for Sse4_2 { } #[inline(always)] fn reinterpret_u8_i32x4(self, a: i32x4) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn reinterpret_u32_i32x4(self, a: i32x4) -> u32x4 { - u32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn cvt_f32_i32x4(self, a: i32x4) -> f32x4 { @@ -1123,10 +1084,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn reinterpret_u8_u32x4(self, a: u32x4) -> u8x16 { - u8x16 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + __m128i::from(a).simd_into(self) } #[inline(always)] fn cvt_f32_u32x4(self, a: u32x4) -> f32x4 { @@ -1296,10 +1254,7 @@ impl Simd for Sse4_2 { } #[inline(always)] fn reinterpret_f32_f64x2(self, a: f64x2) -> f32x4 { - f32x4 { - val: bytemuck::cast(a.val), - simd: a.simd, - } + unsafe { _mm_castpd_ps(a.into()).simd_into(self) } } #[inline(always)] fn splat_mask64x2(self, val: i64) -> mask64x2 { diff --git a/fearless_simd_gen/src/arch/x86.rs b/fearless_simd_gen/src/arch/x86.rs index 809a173..91664d8 100644 --- a/fearless_simd_gen/src/arch/x86.rs +++ b/fearless_simd_gen/src/arch/x86.rs @@ -232,7 +232,8 @@ pub(crate) fn intrinsic_ident(name: &str, suffix: &str, ty_bits: usize) -> Ident pub(crate) fn cast_ident( src_scalar_ty: ScalarType, dst_scalar_ty: ScalarType, - scalar_bits: usize, + src_scalar_bits: usize, + dst_scalar_bits: usize, ty_bits: usize, ) -> Ident { let prefix = match ty_bits { @@ -243,13 +244,13 @@ pub(crate) fn cast_ident( }; let src_name = coarse_type(&VecType::new( src_scalar_ty, - scalar_bits, - ty_bits / scalar_bits, + src_scalar_bits, + ty_bits / src_scalar_bits, )); let dst_name = coarse_type(&VecType::new( dst_scalar_ty, - scalar_bits, - ty_bits / scalar_bits, + dst_scalar_bits, + ty_bits / dst_scalar_bits, )); format_ident!("_mm{prefix}_cast{src_name}_{dst_name}") diff --git a/fearless_simd_gen/src/mk_avx2.rs b/fearless_simd_gen/src/mk_avx2.rs index 5e565d4..4f9e672 100644 --- a/fearless_simd_gen/src/mk_avx2.rs +++ b/fearless_simd_gen/src/mk_avx2.rs @@ -285,6 +285,7 @@ pub(crate) fn handle_compare( ScalarType::Float, ScalarType::Mask, vec_ty.scalar_bits, + vec_ty.scalar_bits, vec_ty.n_bits(), ); diff --git a/fearless_simd_gen/src/mk_fallback.rs b/fearless_simd_gen/src/mk_fallback.rs index 671ad82..14949fc 100644 --- a/fearless_simd_gen/src/mk_fallback.rs +++ b/fearless_simd_gen/src/mk_fallback.rs @@ -3,7 +3,7 @@ use crate::arch::fallback; use crate::generic::{generic_combine, generic_op, generic_split}; -use crate::ops::{OpSig, TyFlavor, ops_for_type, reinterpret_ty, valid_reinterpret}; +use crate::ops::{OpSig, TyFlavor, ops_for_type, valid_reinterpret}; use crate::types::{SIMD_TYPES, ScalarType, VecType, type_imports}; use proc_macro2::{Ident, Span, TokenStream}; use quote::quote; @@ -28,7 +28,7 @@ pub(crate) fn mk_fallback_impl() -> TokenStream { quote! { use core::ops::*; - use crate::{seal::Seal, Level, Simd, SimdInto}; + use crate::{Bytes, seal::Seal, Level, Simd, SimdInto}; #imports @@ -345,14 +345,9 @@ fn mk_simd_impl() -> TokenStream { } OpSig::Reinterpret(scalar, scalar_bits) => { if valid_reinterpret(vec_ty, scalar, scalar_bits) { - let to_ty = reinterpret_ty(vec_ty, scalar, scalar_bits).rust(); - quote! { #method_sig { - #to_ty { - val: bytemuck::cast(a.val), - simd: a.simd, - } + a.bitcast() } } } else { diff --git a/fearless_simd_gen/src/mk_sse4_2.rs b/fearless_simd_gen/src/mk_sse4_2.rs index e0bcd65..57a3965 100644 --- a/fearless_simd_gen/src/mk_sse4_2.rs +++ b/fearless_simd_gen/src/mk_sse4_2.rs @@ -7,7 +7,7 @@ use crate::arch::x86::{ unpack_intrinsic, }; use crate::generic::{generic_combine, generic_op, generic_split, scalar_binary}; -use crate::ops::{OpSig, TyFlavor, ops_for_type, reinterpret_ty, valid_reinterpret}; +use crate::ops::{OpSig, TyFlavor, ops_for_type, valid_reinterpret}; use crate::types::{SIMD_TYPES, ScalarType, VecType, type_imports}; use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote}; @@ -285,6 +285,7 @@ pub(crate) fn handle_compare( ScalarType::Float, ScalarType::Mask, vec_ty.scalar_bits, + vec_ty.scalar_bits, vec_ty.n_bits(), ); quote! { #ident(#expr) } @@ -536,6 +537,7 @@ pub(crate) fn handle_select(method_sig: TokenStream, vec_ty: &VecType) -> TokenS ScalarType::Mask, ScalarType::Float, vec_ty.scalar_bits, + vec_ty.scalar_bits, vec_ty.n_bits(), ); quote! { #ident(a.into()) } @@ -798,19 +800,34 @@ pub(crate) fn handle_reinterpret( scalar: ScalarType, scalar_bits: usize, ) -> TokenStream { - if valid_reinterpret(vec_ty, scalar, scalar_bits) { - let to_ty = reinterpret_ty(vec_ty, scalar, scalar_bits).rust(); + let dst_ty = VecType::new(scalar, scalar_bits, vec_ty.n_bits() / scalar_bits); + assert!( + valid_reinterpret(vec_ty, scalar, scalar_bits), + "{vec_ty:?} must be reinterpretable as {dst_ty:?}" + ); + if coarse_type(vec_ty) == coarse_type(&dst_ty) { + let arch_ty = x86::arch_ty(vec_ty); quote! { #method_sig { - #to_ty { - val: bytemuck::cast(a.val), - simd: a.simd, - } + #arch_ty::from(a).simd_into(self) } } } else { - quote! {} + let ident = cast_ident( + vec_ty.scalar, + scalar, + vec_ty.scalar_bits, + scalar_bits, + vec_ty.n_bits(), + ); + quote! { + #method_sig { + unsafe { + #ident(a.into()).simd_into(self) + } + } + } } }