Skip to content

Commit f0d997e

Browse files
authored
chore: fine tune binomial_bounds.rs code (#103)
Signed-off-by: tison <wander4096@gmail.com>
1 parent c41fe90 commit f0d997e

File tree

1 file changed

+25
-49
lines changed

1 file changed

+25
-49
lines changed

datasketches/src/common/binomial_bounds.rs

Lines changed: 25 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ static UB_EQUIV_TABLE: [f64; 363] = [
275275
/// # Arguments
276276
///
277277
/// * `num_samples`: The number of samples in the sample set.
278-
/// * `theta`: The sampling probability. Must be in the range (0.0, 1.0].
278+
/// * `theta`: The sampling probability. Must be in the range `(0.0, 1.0]`.
279279
/// * `num_std_dev`: The number of standard deviations for confidence bounds.
280280
///
281281
/// # Returns
@@ -284,13 +284,17 @@ static UB_EQUIV_TABLE: [f64; 363] = [
284284
///
285285
/// # Errors
286286
///
287-
/// Returns an error if `theta` is not in the range (0.0, 1.0].
287+
/// Returns an error if `theta` is not in the range `(0.0, 1.0]`.
288288
pub(crate) fn lower_bound(
289289
num_samples: u64,
290290
theta: f64,
291291
num_std_dev: NumStdDev,
292292
) -> Result<f64, Error> {
293-
check_theta(theta)?;
293+
if theta <= 0.0 || theta > 1.0 {
294+
return Err(Error::invalid_argument(format!(
295+
"theta must be in the range (0.0, 1.0], got {theta}"
296+
)));
297+
}
294298

295299
let estimate = num_samples as f64 / theta;
296300
let lb = compute_approx_binomial_lower_bound(num_samples, theta, num_std_dev);
@@ -325,7 +329,12 @@ pub(crate) fn upper_bound(
325329
if no_data_seen {
326330
return Ok(0.0);
327331
}
328-
check_theta(theta)?;
332+
333+
if theta <= 0.0 || theta > 1.0 {
334+
return Err(Error::invalid_argument(format!(
335+
"theta must be in the range (0.0, 1.0], got {theta}"
336+
)));
337+
}
329338

330339
let estimate = num_samples as f64 / theta;
331340
let ub = compute_approx_binomial_upper_bound(num_samples, theta, num_std_dev);
@@ -360,6 +369,7 @@ fn cont_classic_ub(num_samples: u64, theta: f64, num_std_devs: f64) -> f64 {
360369
/// # Limitations
361370
///
362371
/// Outside of the valid input range, two different bad things will happen:
372+
///
363373
/// 1. Because we are not using logarithms, the values of intermediate quantities will exceed the
364374
/// dynamic range of doubles.
365375
/// 2. Even if that problem were fixed, the running time of this procedure is essentially linear in
@@ -548,21 +558,6 @@ fn compute_approx_binomial_upper_bound(
548558
special_n_prime_f(num_samples, theta, delta).unwrap_or(num_samples + 1) as f64 // no need to round
549559
}
550560

551-
/// Validates that theta is in the valid range [0.0, 1.0].
552-
///
553-
/// # Errors
554-
///
555-
/// Returns an error if theta < 0.0 or theta > 1.0.
556-
fn check_theta(theta: f64) -> Result<(), Error> {
557-
if (theta <= 0.0) || (theta > 1.0) {
558-
return Err(Error::invalid_argument(format!(
559-
"theta must be in the range [0.0, 1.0]: {}",
560-
theta
561-
)));
562-
}
563-
Ok(())
564-
}
565-
566561
#[cfg(test)]
567562
mod tests {
568563
use super::*;
@@ -679,53 +674,34 @@ mod tests {
679674
fn check_bounds() {
680675
let mut i = 0;
681676

677+
fn assert_approx_equal(ci: NumStdDev, j: usize, expected: f64, actual: f64) {
678+
let ratio = actual / expected;
679+
assert!(
680+
(ratio - 1.0).abs() < TOL,
681+
"ci={ci:?}, j={j}: expected {expected}, got {actual}, ratio={ratio}",
682+
);
683+
}
684+
682685
for ci in [NumStdDev::One, NumStdDev::Two, NumStdDev::Three] {
683686
let arr = run_test_aux(20, ci, 1e-3);
684687
for j in 0..5 {
685-
let ratio = arr[j] / STD[i][j];
686-
assert!(
687-
(ratio - 1.0).abs() < TOL,
688-
"ci={:?}, j={}: expected {}, got {}, ratio={}",
689-
ci,
690-
j,
691-
STD[i][j],
692-
arr[j],
693-
ratio
694-
);
688+
assert_approx_equal(ci, j, STD[i][j], arr[j]);
695689
}
696690
i += 1;
697691
}
698692

699693
for ci in [NumStdDev::One, NumStdDev::Two, NumStdDev::Three] {
700694
let arr = run_test_aux(200, ci, 1e-5);
701695
for j in 0..5 {
702-
let ratio = arr[j] / STD[i][j];
703-
assert!(
704-
(ratio - 1.0) < TOL,
705-
"ci={:?}, j={}: expected {}, got {}, ratio={}",
706-
ci,
707-
j,
708-
STD[i][j],
709-
arr[j],
710-
ratio
711-
);
696+
assert_approx_equal(ci, j, STD[i][j], arr[j]);
712697
}
713698
i += 1;
714699
}
715700

716701
for ci in [NumStdDev::One, NumStdDev::Two, NumStdDev::Three] {
717702
let arr = run_test_aux(2000, ci, 1e-7);
718703
for j in 0..5 {
719-
let ratio = arr[j] / STD[i][j];
720-
assert!(
721-
(ratio - 1.0).abs() < TOL,
722-
"ci={:?}, j={}: expected {}, got {}, ratio={}",
723-
ci,
724-
j,
725-
STD[i][j],
726-
arr[j],
727-
ratio
728-
);
704+
assert_approx_equal(ci, j, STD[i][j], arr[j]);
729705
}
730706
i += 1;
731707
}

0 commit comments

Comments
 (0)