Skip to content

Commit 92d100d

Browse files
author
Ian
committed
Boosted performance
Added generic types for tests; version bump
1 parent 9d15840 commit 92d100d

File tree

3 files changed

+55
-47
lines changed

3 files changed

+55
-47
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "single-statistics"
3-
version = "0.3.0"
3+
version = "0.4.0"
44
edition = "2024"
55
license-file = "LICENSE.md"
66
readme = "README.md"

src/testing/inference/parametric.rs

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,28 @@ where
2828
group1_values.clear();
2929
group2_values.clear();
3030

31-
// Extract values for this gene
31+
// Extract values for this gene (column gene_idx)
32+
// For each cell in group1, get the gene expression value
3233
for &cell_idx in group1_indices {
3334
let value = if let Some(entry) = matrix.get_entry(cell_idx, gene_idx) {
3435
entry.into_value()
3536
} else {
36-
T::zero()
37+
T::zero() // Handle sparse entries
3738
};
3839
group1_values.push(value);
3940
}
4041

42+
// For each cell in group2, get the gene expression value
4143
for &cell_idx in group2_indices {
4244
let value = if let Some(entry) = matrix.get_entry(cell_idx, gene_idx) {
4345
entry.into_value()
4446
} else {
45-
T::zero()
47+
T::zero() // Handle sparse entries
4648
};
4749
group2_values.push(value);
4850
}
4951

50-
// Run optimized t-test
52+
// Run t-test for this gene
5153
let result = t_test(
5254
&group1_values,
5355
&group2_values,
@@ -71,52 +73,51 @@ where
7173
return TestResult::new(T::zero(), T::one());
7274
}
7375

74-
// Single-pass calculation of means and sums of squares
75-
let (sum_x, sum_sq_x) = x
76-
.iter()
77-
.fold((T::zero(), T::zero()), |(sum, sum_sq), &val| {
78-
(sum + val, sum_sq + val * val)
79-
});
80-
81-
let (sum_y, sum_sq_y) = y
82-
.iter()
83-
.fold((T::zero(), T::zero()), |(sum, sum_sq), &val| {
84-
(sum + val, sum_sq + val * val)
85-
});
76+
// Sum each group's values; the means are derived from these sums below
77+
let sum_x: T = x.iter().copied().sum();
78+
let sum_y: T = y.iter().copied().sum();
8679

8780
let nx_f = T::from(nx).unwrap();
8881
let ny_f = T::from(ny).unwrap();
8982

9083
let mean_x = sum_x / nx_f;
9184
let mean_y = sum_y / ny_f;
9285

93-
// Calculate variances more efficiently
94-
let var_x = (sum_sq_x - sum_x * mean_x) / (nx_f - T::one());
95-
let var_y = (sum_sq_y - sum_y * mean_y) / (ny_f - T::one());
86+
// Calculate sample variances with the two-pass formula: sum of squared deviations from the mean, divided by (n - 1)
87+
let var_x = x.iter()
88+
.map(|&val| (val - mean_x) * (val - mean_x))
89+
.sum::<T>() / (nx_f - T::one());
90+
91+
let var_y = y.iter()
92+
.map(|&val| (val - mean_y) * (val - mean_y))
93+
.sum::<T>() / (ny_f - T::one());
9694

9795
// Early exit for zero variance cases
9896
if var_x <= T::zero() && var_y <= T::zero() {
9997
if num_traits::Float::abs(mean_x - mean_y) < <T as num_traits::Float>::epsilon() {
10098
return TestResult::new(T::zero(), T::one()); // No difference, no variance
10199
} else {
102-
return TestResult::new(<T as num_traits::Float>::infinity(), T::one()); // Infinite t-stat
100+
return TestResult::new(<T as num_traits::Float>::infinity(), T::zero()); // Infinite t-stat, highly significant
103101
}
104102
}
105103

106104
let (t_stat, df) = match test_type {
107105
TTestType::Student => {
106+
// Pooled variance (equal variances assumed)
108107
let pooled_var = ((nx_f - T::one()) * var_x + (ny_f - T::one()) * var_y)
109-
/ (nx_f + ny_f - T::from_f64(2.0).unwrap());
108+
/ (nx_f + ny_f - T::from(2.0).unwrap());
110109

111110
if pooled_var <= T::zero() {
112111
return TestResult::new(<T as num_traits::Float>::infinity(), T::zero());
113112
}
114113

115114
let std_err = (pooled_var * (T::one() / nx_f + T::one() / ny_f)).sqrt();
116115
let t = (mean_x - mean_y) / std_err;
117-
(t, nx_f + ny_f - T::from(2.0).unwrap())
116+
let degrees_freedom = nx_f + ny_f - T::from(2.0).unwrap();
117+
(t, degrees_freedom)
118118
}
119119
TTestType::Welch => {
120+
// Welch's t-test (unequal variances)
120121
let term1 = var_x / nx_f;
121122
let term2 = var_y / ny_f;
122123
let combined_var = term1 + term2;
@@ -151,32 +152,39 @@ where
151152
return TestResult::new(t_stat, T::one());
152153
}
153154

154-
// Use pre-computed t-distribution if possible, or fall back to normal approximation for large df
155-
let p_value = if df > T::from(100.0).unwrap() {
156-
// Normal approximation for large degrees of freedom (faster)
157-
let abs_t = num_traits::Float::abs(t_stat);
158-
match alternative {
159-
Alternative::TwoSided => T::from(2.0).unwrap() * normal_cdf_complement(abs_t),
160-
Alternative::Less => normal_cdf(-t_stat),
161-
Alternative::Greater => normal_cdf_complement(t_stat),
162-
}
163-
} else {
164-
// Use t-distribution for smaller degrees of freedom
165-
let tstat_f64 = t_stat.to_f64().unwrap();
166-
match StudentsT::new(0.0, 1.0, df.to_f64().unwrap()) {
167-
Ok(t_dist) => match alternative {
155+
// Calculate p-value using t-distribution
156+
let p_value = calculate_p_value(t_stat, df, alternative);
157+
158+
TestResult::new(t_stat, num_traits::Float::clamp(p_value, T::zero(), T::one()))
159+
}
160+
161+
fn calculate_p_value<T>(t_stat: T, df: T, alternative: Alternative) -> T
162+
where
163+
T: FloatOps,
164+
{
165+
let t_f64 = t_stat.to_f64().unwrap();
166+
let df_f64 = df.to_f64().unwrap();
167+
168+
match StudentsT::new(0.0, 1.0, df_f64) {
169+
Ok(t_dist) => {
170+
let p = match alternative {
168171
Alternative::TwoSided => {
169-
T::from(2.0).unwrap()
170-
* (T::one() - T::from(t_dist.cdf(tstat_f64.abs())).unwrap())
172+
// Two-tailed test
173+
2.0 * (1.0 - t_dist.cdf(t_f64.abs()))
171174
}
172-
Alternative::Less => T::from(t_dist.cdf(-tstat_f64)).unwrap(),
173-
Alternative::Greater => T::one() - T::from(t_dist.cdf(tstat_f64)).unwrap(),
174-
},
175-
Err(_) => T::one(), // Fallback
175+
Alternative::Less => {
176+
// Left-tailed test: P(T <= t)
177+
t_dist.cdf(t_f64)
178+
}
179+
Alternative::Greater => {
180+
// Right-tailed test: P(T >= t)
181+
1.0 - t_dist.cdf(t_f64)
182+
}
183+
};
184+
T::from(p).unwrap()
176185
}
177-
};
178-
179-
TestResult::new(t_stat, num_traits::Float::clamp(p_value, T::zero(), T::one()))
186+
Err(_) => T::one(), // Fallback for invalid parameters
187+
}
180188
}
181189

182190
pub fn student_t_quantile<T>(p: T, df: T) -> T

0 commit comments

Comments
 (0)