@@ -275,7 +275,7 @@ static UB_EQUIV_TABLE: [f64; 363] = [
275275/// # Arguments
276276///
277277/// * `num_samples`: The number of samples in the sample set.
278- /// * `theta`: The sampling probability. Must be in the range (0.0, 1.0].
278+ /// * `theta`: The sampling probability. Must be in the range ` (0.0, 1.0]` .
279279/// * `num_std_dev`: The number of standard deviations for confidence bounds.
280280///
281281/// # Returns
@@ -284,13 +284,17 @@ static UB_EQUIV_TABLE: [f64; 363] = [
284284///
285285/// # Errors
286286///
287- /// Returns an error if `theta` is not in the range (0.0, 1.0].
287+ /// Returns an error if `theta` is not in the range ` (0.0, 1.0]` .
288288pub ( crate ) fn lower_bound (
289289 num_samples : u64 ,
290290 theta : f64 ,
291291 num_std_dev : NumStdDev ,
292292) -> Result < f64 , Error > {
293- check_theta ( theta) ?;
293+ if theta <= 0.0 || theta > 1.0 {
294+ return Err ( Error :: invalid_argument ( format ! (
295+ "theta must be in the range (0.0, 1.0], got {theta}"
296+ ) ) ) ;
297+ }
294298
295299 let estimate = num_samples as f64 / theta;
296300 let lb = compute_approx_binomial_lower_bound ( num_samples, theta, num_std_dev) ;
@@ -325,7 +329,12 @@ pub(crate) fn upper_bound(
325329 if no_data_seen {
326330 return Ok ( 0.0 ) ;
327331 }
328- check_theta ( theta) ?;
332+
333+ if theta <= 0.0 || theta > 1.0 {
334+ return Err ( Error :: invalid_argument ( format ! (
335+ "theta must be in the range (0.0, 1.0], got {theta}"
336+ ) ) ) ;
337+ }
329338
330339 let estimate = num_samples as f64 / theta;
331340 let ub = compute_approx_binomial_upper_bound ( num_samples, theta, num_std_dev) ;
@@ -360,6 +369,7 @@ fn cont_classic_ub(num_samples: u64, theta: f64, num_std_devs: f64) -> f64 {
360369/// # Limitations
361370///
362371/// Outside of the valid input range, two different bad things will happen:
372+ ///
363373/// 1. Because we are not using logarithms, the values of intermediate quantities will exceed the
364374/// dynamic range of doubles.
365375/// 2. Even if that problem were fixed, the running time of this procedure is essentially linear in
@@ -548,21 +558,6 @@ fn compute_approx_binomial_upper_bound(
548558 special_n_prime_f ( num_samples, theta, delta) . unwrap_or ( num_samples + 1 ) as f64 // no need to round
549559}
550560
551- /// Validates that theta is in the valid range [0.0, 1.0].
552- ///
553- /// # Errors
554- ///
555- /// Returns an error if theta < 0.0 or theta > 1.0.
556- fn check_theta ( theta : f64 ) -> Result < ( ) , Error > {
557- if ( theta <= 0.0 ) || ( theta > 1.0 ) {
558- return Err ( Error :: invalid_argument ( format ! (
559- "theta must be in the range [0.0, 1.0]: {}" ,
560- theta
561- ) ) ) ;
562- }
563- Ok ( ( ) )
564- }
565-
566561#[ cfg( test) ]
567562mod tests {
568563 use super :: * ;
@@ -679,53 +674,34 @@ mod tests {
679674 fn check_bounds ( ) {
680675 let mut i = 0 ;
681676
677+ fn assert_approx_equal ( ci : NumStdDev , j : usize , expected : f64 , actual : f64 ) {
678+ let ratio = actual / expected;
679+ assert ! (
680+ ( ratio - 1.0 ) . abs( ) < TOL ,
681+ "ci={ci:?}, j={j}: expected {expected}, got {actual}, ratio={ratio}" ,
682+ ) ;
683+ }
684+
682685 for ci in [ NumStdDev :: One , NumStdDev :: Two , NumStdDev :: Three ] {
683686 let arr = run_test_aux ( 20 , ci, 1e-3 ) ;
684687 for j in 0 ..5 {
685- let ratio = arr[ j] / STD [ i] [ j] ;
686- assert ! (
687- ( ratio - 1.0 ) . abs( ) < TOL ,
688- "ci={:?}, j={}: expected {}, got {}, ratio={}" ,
689- ci,
690- j,
691- STD [ i] [ j] ,
692- arr[ j] ,
693- ratio
694- ) ;
688+ assert_approx_equal ( ci, j, STD [ i] [ j] , arr[ j] ) ;
695689 }
696690 i += 1 ;
697691 }
698692
699693 for ci in [ NumStdDev :: One , NumStdDev :: Two , NumStdDev :: Three ] {
700694 let arr = run_test_aux ( 200 , ci, 1e-5 ) ;
701695 for j in 0 ..5 {
702- let ratio = arr[ j] / STD [ i] [ j] ;
703- assert ! (
704- ( ratio - 1.0 ) < TOL ,
705- "ci={:?}, j={}: expected {}, got {}, ratio={}" ,
706- ci,
707- j,
708- STD [ i] [ j] ,
709- arr[ j] ,
710- ratio
711- ) ;
696+ assert_approx_equal ( ci, j, STD [ i] [ j] , arr[ j] ) ;
712697 }
713698 i += 1 ;
714699 }
715700
716701 for ci in [ NumStdDev :: One , NumStdDev :: Two , NumStdDev :: Three ] {
717702 let arr = run_test_aux ( 2000 , ci, 1e-7 ) ;
718703 for j in 0 ..5 {
719- let ratio = arr[ j] / STD [ i] [ j] ;
720- assert ! (
721- ( ratio - 1.0 ) . abs( ) < TOL ,
722- "ci={:?}, j={}: expected {}, got {}, ratio={}" ,
723- ci,
724- j,
725- STD [ i] [ j] ,
726- arr[ j] ,
727- ratio
728- ) ;
704+ assert_approx_equal ( ci, j, STD [ i] [ j] , arr[ j] ) ;
729705 }
730706 i += 1 ;
731707 }
0 commit comments