Skip to content

Commit a6aaf41

Browse files
committed
fix: resolve AVX-512 build errors (duplicate function and type mismatch)
1 parent 5decf4c commit a6aaf41

File tree

1 file changed

+27
-146
lines changed

1 file changed

+27
-146
lines changed

src/simd.rs

Lines changed: 27 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -65,29 +65,32 @@ cfg_if! {
6565
}
6666

6767
/// SIMD-accelerated conversion from bytes to MB (AVX-512).
68-
/// Uses the powerful _mm512_cvtepu64_ps to convert 8 u64s to f32s in one go.
6968
pub fn convert_bytes_to_mb_simd(bytes: &[usize], mb: &mut [f32]) {
7069
let len = bytes.len();
7170
let mut i = 0;
7271
const DIV: f32 = 1.0 / (1024.0 * 1024.0);
7372

7473
unsafe {
75-
while i + 8 <= len {
76-
// Load 8 u64 values into a 512-bit register
77-
let v_u64 = _mm512_loadu_si512(bytes.as_ptr().add(i) as *const _);
74+
let div_v = _mm512_set1_ps(DIV);
75+
while i + 16 <= len {
76+
// Load 16 u64 values (requires two 512-bit registers)
77+
let v_u64_1 = _mm512_loadu_si512(bytes.as_ptr().add(i) as *const _);
78+
let v_u64_2 = _mm512_loadu_si512(bytes.as_ptr().add(i + 8) as *const _);
7879

79-
// AVX-512 specific: Direct conversion from 8 unsigned 64-bit ints to 8 floats
80-
let v_f32 = _mm512_cvtepu64_ps(v_u64);
80+
// Convert 8 u64s to 8 f32s (each produces __m256)
81+
let v_f32_1 = _mm512_cvtepu64_ps(v_u64_1);
82+
let v_f32_2 = _mm512_cvtepu64_ps(v_u64_2);
83+
84+
// Combine two __m256 into one __m512
85+
let v_f512 = _mm512_insertf32x8(_mm512_castps256_ps512(v_f32_1), v_f32_2, 1);
8186

8287
// Multiply by reciprocal of 1MB
83-
// Note: _mm512_cvtepu64_ps returns __m256 (8 floats), so we use AVX for multiplication.
84-
let div_v_f256 = _mm256_set1_ps(DIV);
85-
let res = _mm256_mul_ps(v_f32, div_v_f256);
88+
let res = _mm512_mul_ps(v_f512, div_v);
8689

87-
// Store 8 floats (256 bits of data)
88-
_mm256_storeu_ps(mb.as_mut_ptr().add(i), res);
90+
// Store 16 floats (512 bits)
91+
_mm512_storeu_ps(mb.as_mut_ptr().add(i), res);
8992

90-
i += 8;
93+
i += 16;
9194
}
9295
}
9396

@@ -97,137 +100,6 @@ cfg_if! {
97100
}
98101

99102

100-
/// SIMD-accelerated total multiplier calculation (AVX2).
101-
pub fn calculate_total_multipliers_simd(
102-
priorities: &[u32],
103-
elevations: &[bool],
104-
games: &[bool],
105-
foregrounds: &[bool],
106-
windows: &[bool],
107-
uptimes: &[u64],
108-
mults: &mut [f32]
109-
) {
110-
let len = priorities.len();
111-
let mut i = 0;
112-
113-
unsafe {
114-
let v_one = _mm256_set1_ps(1.0);
115-
let v_half = _mm256_set1_ps(0.5);
116-
let v_small = _mm256_set1_ps(0.01);
117-
let v_zero_i = _mm256_setzero_si256();
118-
119-
// Priority constants
120-
let p_idle = _mm256_set1_epi32(0x40);
121-
let p_high = _mm256_set1_epi32(0x80);
122-
let p_real = _mm256_set1_epi32(0x100);
123-
let p_below = _mm256_set1_epi32(0x4000);
124-
let p_above = _mm256_set1_epi32(0x8000);
125-
126-
let m_idle = _mm256_set1_ps(2.0);
127-
let m_high = _mm256_set1_ps(0.2);
128-
let m_real = _mm256_set1_ps(0.01);
129-
let m_below = _mm256_set1_ps(1.5);
130-
let m_above = _mm256_set1_ps(0.5);
131-
132-
// Uptime constants
133-
let u_div = _mm256_set1_ps(1.0 / 3600.0 * 0.1);
134-
let u_min = _mm256_set1_ps(0.7);
135-
136-
while i + 8 <= len {
137-
// 1. Priority Multiplier
138-
let v_prio = _mm256_loadu_si256(priorities.as_ptr().add(i) as *const _);
139-
let mut v_p_mult = v_one;
140-
141-
// Blend based on matches (cmpeq returns integer mask, cast to float for blendv_ps)
142-
let mask_idle = _mm256_castsi256_ps(_mm256_cmpeq_epi32(v_prio, p_idle));
143-
v_p_mult = _mm256_blendv_ps(v_p_mult, m_idle, mask_idle);
144-
145-
let mask_high = _mm256_castsi256_ps(_mm256_cmpeq_epi32(v_prio, p_high));
146-
v_p_mult = _mm256_blendv_ps(v_p_mult, m_high, mask_high);
147-
148-
let mask_real = _mm256_castsi256_ps(_mm256_cmpeq_epi32(v_prio, p_real));
149-
v_p_mult = _mm256_blendv_ps(v_p_mult, m_real, mask_real);
150-
151-
let mask_below = _mm256_castsi256_ps(_mm256_cmpeq_epi32(v_prio, p_below));
152-
v_p_mult = _mm256_blendv_ps(v_p_mult, m_below, mask_below);
153-
154-
let mask_above = _mm256_castsi256_ps(_mm256_cmpeq_epi32(v_prio, p_above));
155-
v_p_mult = _mm256_blendv_ps(v_p_mult, m_above, mask_above);
156-
157-
// 2. Boolean Multipliers
158-
// Load 8 bytes (lower half of XMM register)
159-
// Note: _mm_loadl_epi64 loads 64 bits.
160-
161-
// Elevation
162-
let v_elev_u64 = _mm_loadl_epi64(elevations.as_ptr().add(i) as *const _);
163-
let v_elev_i32 = _mm256_cvtepu8_epi32(v_elev_u64);
164-
let v_elev_f32 = _mm256_cvtepi32_ps(v_elev_i32);
165-
let v_e_mult = _mm256_sub_ps(v_one, _mm256_mul_ps(v_elev_f32, v_half));
166-
167-
// Game
168-
let v_game_u64 = _mm_loadl_epi64(games.as_ptr().add(i) as *const _);
169-
let v_game_i32 = _mm256_cvtepu8_epi32(v_game_u64);
170-
// mask: if val != 0
171-
let mask_game_i = _mm256_cmpeq_epi32(v_game_i32, v_zero_i); // 0xFFFF if 0 (false), 0 if 1 (true)
172-
// We want 1.0 if false, 0.01 if true.
173-
// blendv picks second arg if mask bit is 1.
174-
// if mask_game_i is all ones (false), we pick v_one.
175-
let v_g_mult = _mm256_blendv_ps(v_small, v_one, _mm256_castsi256_ps(mask_game_i));
176-
177-
// Foreground
178-
let v_fore_u64 = _mm_loadl_epi64(foregrounds.as_ptr().add(i) as *const _);
179-
let v_fore_i32 = _mm256_cvtepu8_epi32(v_fore_u64);
180-
let mask_fore_i = _mm256_cmpeq_epi32(v_fore_i32, v_zero_i);
181-
let v_f_mult = _mm256_blendv_ps(v_small, v_one, _mm256_castsi256_ps(mask_fore_i));
182-
183-
// Window
184-
let v_win_u64 = _mm_loadl_epi64(windows.as_ptr().add(i) as *const _);
185-
let v_win_i32 = _mm256_cvtepu8_epi32(v_win_u64);
186-
let mask_win_i = _mm256_cmpeq_epi32(v_win_i32, v_zero_i);
187-
// if false (0) -> 1.0. if true (1) -> 0.5.
188-
let v_w_mult = _mm256_blendv_ps(v_half, v_one, _mm256_castsi256_ps(mask_win_i));
189-
190-
// 3. Uptime
191-
let ptr = uptimes.as_ptr().add(i);
192-
let v_upt_f32 = _mm256_setr_ps(
193-
*ptr.add(0) as f32, *ptr.add(1) as f32, *ptr.add(2) as f32, *ptr.add(3) as f32,
194-
*ptr.add(4) as f32, *ptr.add(5) as f32, *ptr.add(6) as f32, *ptr.add(7) as f32,
195-
);
196-
let v_u_sub = _mm256_mul_ps(v_upt_f32, u_div);
197-
let v_u_mult = _mm256_max_ps(u_min, _mm256_sub_ps(v_one, v_u_sub));
198-
199-
// 4. Combine
200-
let mut total = v_p_mult;
201-
total = _mm256_mul_ps(total, v_e_mult);
202-
total = _mm256_mul_ps(total, v_g_mult);
203-
total = _mm256_mul_ps(total, v_f_mult);
204-
total = _mm256_mul_ps(total, v_w_mult);
205-
total = _mm256_mul_ps(total, v_u_mult);
206-
207-
_mm256_storeu_ps(mults.as_mut_ptr().add(i), total);
208-
i += 8;
209-
}
210-
}
211-
212-
// Scalar fallback
213-
use crate::scoring::{
214-
get_priority_multiplier, get_elevation_multiplier,
215-
get_game_multiplier, get_foreground_multiplier,
216-
get_window_multiplier, get_uptime_multiplier
217-
};
218-
use windows::Win32::System::Threading::PROCESS_CREATION_FLAGS;
219-
220-
for j in i..len {
221-
let p = get_priority_multiplier(PROCESS_CREATION_FLAGS(priorities[j]));
222-
let e = get_elevation_multiplier(elevations[j]);
223-
let g = get_game_multiplier(games[j]);
224-
let f = get_foreground_multiplier(foregrounds[j]);
225-
let w = get_window_multiplier(windows[j]);
226-
let u = get_uptime_multiplier(uptimes[j]);
227-
mults[j] = p * e * g * f * w * u;
228-
}
229-
}
230-
231103
/// SIMD-accelerated total multiplier calculation (AVX-512).
232104
/// Combines all multipliers (Priority, Elevation, Game, Foreground, Window, Uptime) into one.
233105
pub fn calculate_total_multipliers_simd(
@@ -314,9 +186,18 @@ cfg_if! {
314186

315187
// 3. Uptime Multiplier
316188
// (1.0 - (uptime / 3600.0 * 0.1)).max(0.7)
317-
let v_upt_u64 = _mm512_loadu_si512(uptimes.as_ptr().add(i) as *const _);
318-
let v_upt_f32 = _mm512_cvtepu64_ps(v_upt_u64);
319-
let v_u_sub = _mm512_mul_ps(v_upt_f32, u_div);
189+
// Load 16 u64s in two batches of 8
190+
let v_upt_u64_1 = _mm512_loadu_si512(uptimes.as_ptr().add(i) as *const _);
191+
let v_upt_u64_2 = _mm512_loadu_si512(uptimes.as_ptr().add(i + 8) as *const _);
192+
193+
// Convert 8 u64s to 8 f32s (each produces __m256)
194+
let v_upt_f32_1 = _mm512_cvtepu64_ps(v_upt_u64_1);
195+
let v_upt_f32_2 = _mm512_cvtepu64_ps(v_upt_u64_2);
196+
197+
// Combine into one __m512 (16 floats)
198+
let v_upt_f512 = _mm512_insertf32x8(_mm512_castps256_ps512(v_upt_f32_1), v_upt_f32_2, 1);
199+
200+
let v_u_sub = _mm512_mul_ps(v_upt_f512, u_div);
320201
let v_u_mult = _mm512_max_ps(u_min, _mm512_sub_ps(v_one, v_u_sub));
321202

322203
// 4. Combine all

0 commit comments

Comments
 (0)