Skip to content

Commit 59c352d

Browse files
committed
fix: update sequential_add_mul (allow for minor precision error)
1 parent cd68c99 commit 59c352d

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

downsample_rs/src/m4.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,7 @@ mod tests {
447447
let idxs3 = m4_without_x_parallel(arr.as_slice(), n_out);
448448
let idxs4 = m4_with_x_parallel(&x, arr.as_slice(), n_out);
449449
assert_eq!(idxs1, idxs3);
450+
// TODO: check whether this still fails after fixing the sequential_add_mul
450451
assert_eq!(idxs1, idxs4); // TODO: this fails when nb. of threads = 16
451452
}
452453
}

downsample_rs/src/searchsorted.rs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ fn sequential_add_mul(start_val: f64, add_val: f64, mul: usize) -> f64 {
122122
// larger than the largest positive f64 number.
123123
// This code should not fail when: (f64::MAX - start_val) < (add_val * mul).
124124
// -> Note that f64::MAX - start_val can be up to 2 * f64::MAX.
125-
let mul_2: usize = mul / 2;
126-
start_val + add_val * mul_2 as f64 + add_val * (mul - mul_2) as f64
125+
let mul_2: f64 = mul as f64 / 2.0;
126+
// start_val + add_val * mul_2 as f64 + add_val * (mul - mul_2) as f64
127+
start_val + add_val * mul_2 + add_val * mul_2
127128
}
128129

129130
pub(crate) fn get_equidistant_bin_idx_iterator_parallel<T>(
@@ -139,10 +140,12 @@ where
139140
let val_step: f64 =
140141
(arr[arr.len() - 1].as_() / nb_bins as f64) - (arr[0].as_() / nb_bins as f64);
141142
let arr0: f64 = arr[0].as_(); // The first value of the array
142-
// 2. Compute the number of threads & bins per thread
143+
144+
// 2. Compute the number of threads & bins per thread
143145
let n_threads = std::cmp::min(POOL.current_num_threads(), nb_bins);
144146
let nb_bins_per_thread = nb_bins / n_threads;
145147
let nb_bins_last_thread = nb_bins - nb_bins_per_thread * (n_threads - 1);
148+
146149
// 3. Iterate over the number of threads
147150
// -> for each thread perform the binary search sorted with moving left and
148151
// yield the indices (using the same idea as for the sequential version)
@@ -198,6 +201,20 @@ mod tests {
198201
#[case(101)]
199202
fn nb_bins(#[case] nb_bins: usize) {}
200203

204+
#[test]
205+
fn test_sequential_add_mul() {
206+
assert_eq!(sequential_add_mul(0.0, 1.0, 0), 0.0);
207+
assert_eq!(sequential_add_mul(-1.0, 1.0, 1), 0.0);
208+
// Really large values
209+
assert_eq!(sequential_add_mul(0.0, 1.0, 1_000_000), 1_000_000.0);
210+
// TODO: the next tests fails due to very minor precision error
211+
// -> however, this precision error is needed to avoid the issue with m4_with_x
212+
// assert_eq!(
213+
// sequential_add_mul(f64::MIN, f64::MAX / 2.0, 3),
214+
// f64::MIN + f64::MAX / 2.0 + f64::MAX
215+
// );
216+
}
217+
201218
#[test]
202219
fn test_search_sorted_identicial_to_np_linspace_searchsorted() {
203220
// Create a 0..9999 array

0 commit comments

Comments
 (0)