1+ use crate :: testing:: utils:: accumulate_gene_statistics_two_groups;
12use crate :: testing:: { Alternative , TTestType , TestResult } ;
23use nalgebra_sparse:: CsrMatrix ;
4+ use rayon:: iter:: ParallelIterator ;
5+ use rayon:: prelude:: IntoParallelIterator ;
36use single_utilities:: traits:: { FloatOps , FloatOpsTS } ;
47use statrs:: distribution:: { ContinuousCDF , StudentsT } ;
58
@@ -17,47 +20,26 @@ where
1720 }
1821
1922 let n_genes = matrix. ncols ( ) ;
20- let mut results = Vec :: with_capacity ( n_genes) ;
21-
22- // Pre-allocate working vectors
23- let mut group1_values = Vec :: with_capacity ( group1_indices. len ( ) ) ;
24- let mut group2_values = Vec :: with_capacity ( group2_indices. len ( ) ) ;
25-
26- for gene_idx in 0 ..n_genes {
27- // Clear and reuse vectors
28- group1_values. clear ( ) ;
29- group2_values. clear ( ) ;
30-
31- // Extract values for this gene (column gene_idx)
32- // For each cell in group1, get the gene expression value
33- for & cell_idx in group1_indices {
34- let value = if let Some ( entry) = matrix. get_entry ( cell_idx, gene_idx) {
35- entry. into_value ( )
36- } else {
37- T :: zero ( ) // Handle sparse entries
38- } ;
39- group1_values. push ( value) ;
40- }
41-
42- // For each cell in group2, get the gene expression value
43- for & cell_idx in group2_indices {
44- let value = if let Some ( entry) = matrix. get_entry ( cell_idx, gene_idx) {
45- entry. into_value ( )
46- } else {
47- T :: zero ( ) // Handle sparse entries
48- } ;
49- group2_values. push ( value) ;
50- }
51-
52- // Run t-test for this gene
53- let result = t_test (
54- & group1_values,
55- & group2_values,
56- test_type,
57- Alternative :: TwoSided ,
58- ) ;
59- results. push ( result) ;
60- }
23+ let group1_size = T :: from ( group1_indices. len ( ) ) . unwrap ( ) ;
24+ let group2_size = T :: from ( group2_indices. len ( ) ) . unwrap ( ) ;
25+
26+ let ( group1_sums, group1_sum_squares, group2_sums, group2_sum_squares) =
27+ accumulate_gene_statistics_two_groups ( matrix, group1_indices, group2_indices) ?;
28+
29+ let results: Vec < TestResult < T > > = ( 0 ..n_genes)
30+ . into_iter ( )
31+ . map ( |gene_idx| {
32+ fast_t_test_from_sums (
33+ group1_sums[ gene_idx] ,
34+ group1_sum_squares[ gene_idx] ,
35+ group1_size,
36+ group2_sums[ gene_idx] ,
37+ group2_sum_squares[ gene_idx] ,
38+ group2_size,
39+ test_type,
40+ )
41+ } )
42+ . collect ( ) ;
6143
6244 Ok ( results)
6345}
@@ -73,89 +55,121 @@ where
7355 return TestResult :: new ( T :: zero ( ) , T :: one ( ) ) ;
7456 }
7557
76- // Calculate means
77- let sum_x: T = x. iter ( ) . copied ( ) . sum ( ) ;
78- let sum_y: T = y. iter ( ) . copied ( ) . sum ( ) ;
58+ let ( sum_x, sum_sq_x) = x
59+ . iter ( )
60+ . fold ( ( T :: zero ( ) , T :: zero ( ) ) , |( sum, sum_sq) , & val| {
61+ ( sum + val, sum_sq + val * val)
62+ } ) ;
63+
64+ let ( sum_y, sum_sq_y) = y
65+ . iter ( )
66+ . fold ( ( T :: zero ( ) , T :: zero ( ) ) , |( sum, sum_sq) , & val| {
67+ ( sum + val, sum_sq + val * val)
68+ } ) ;
7969
8070 let nx_f = T :: from ( nx) . unwrap ( ) ;
8171 let ny_f = T :: from ( ny) . unwrap ( ) ;
8272
83- let mean_x = sum_x / nx_f;
84- let mean_y = sum_y / ny_f;
73+ fast_t_test_from_sums ( sum_x, sum_sq_x, nx_f, sum_y, sum_sq_y, ny_f, test_type)
74+ }
75+
76+ fn fast_t_test_from_sums < T > (
77+ sum1 : T ,
78+ sum_sq1 : T ,
79+ n1 : T ,
80+ sum2 : T ,
81+ sum_sq2 : T ,
82+ n2 : T ,
83+ test_type : TTestType ,
84+ ) -> TestResult < T >
85+ where
86+ T : FloatOps ,
87+ {
88+ if n1 < T :: from ( 2.0 ) . unwrap ( ) || n2 < T :: from ( 2.0 ) . unwrap ( ) {
89+ return TestResult :: new ( T :: zero ( ) , T :: one ( ) ) ;
90+ }
8591
86- // Calculate sample variances using the correct formula
87- let var_x = x. iter ( )
88- . map ( |& val| ( val - mean_x) * ( val - mean_x) )
89- . sum :: < T > ( ) / ( nx_f - T :: one ( ) ) ;
92+ let mean1 = sum1 / n1;
93+ let mean2 = sum2 / n2;
9094
91- let var_y = y. iter ( )
92- . map ( |& val| ( val - mean_y) * ( val - mean_y) )
93- . sum :: < T > ( ) / ( ny_f - T :: one ( ) ) ;
95+ let var1 = ( sum_sq1 - sum1 * mean1) / ( n1 - T :: one ( ) ) ;
96+ let var2 = ( sum_sq2 - sum2 * mean2) / ( n2 - T :: one ( ) ) ;
9497
95- // Early exit for zero variance cases
96- if var_x <= T :: zero ( ) && var_y <= T :: zero ( ) {
97- if num_traits:: Float :: abs ( mean_x - mean_y) < <T as num_traits:: Float >:: epsilon ( ) {
98- return TestResult :: new ( T :: zero ( ) , T :: one ( ) ) ; // No difference, no variance
98+ if var1 <= T :: zero ( ) && var2 <= T :: zero ( ) {
99+ if num_traits:: Float :: abs ( mean1 - mean2) < <T as num_traits:: Float >:: epsilon ( ) {
100+ return TestResult :: new ( T :: zero ( ) , T :: one ( ) ) ;
99101 } else {
100- return TestResult :: new ( <T as num_traits:: Float >:: infinity ( ) , T :: zero ( ) ) ; // Infinite t-stat, highly significant
102+ return TestResult :: new ( <T as num_traits:: Float >:: infinity ( ) , T :: zero ( ) ) ;
101103 }
102104 }
103105
104106 let ( t_stat, df) = match test_type {
105107 TTestType :: Student => {
106- // Pooled variance (equal variances assumed)
107- let pooled_var = ( ( nx_f - T :: one ( ) ) * var_x + ( ny_f - T :: one ( ) ) * var_y)
108- / ( nx_f + ny_f - T :: from ( 2.0 ) . unwrap ( ) ) ;
108+ let pooled_var = ( ( n1 - T :: one ( ) ) * var1 + ( n2 - T :: one ( ) ) * var2)
109+ / ( n1 + n2 - T :: from ( 2.0 ) . unwrap ( ) ) ;
109110
110111 if pooled_var <= T :: zero ( ) {
111112 return TestResult :: new ( <T as num_traits:: Float >:: infinity ( ) , T :: zero ( ) ) ;
112113 }
113114
114- let std_err = ( pooled_var * ( T :: one ( ) / nx_f + T :: one ( ) / ny_f) ) . sqrt ( ) ;
115- let t = ( mean_x - mean_y) / std_err;
116- let degrees_freedom = nx_f + ny_f - T :: from ( 2.0 ) . unwrap ( ) ;
117- ( t, degrees_freedom)
115+ let std_err = ( pooled_var * ( T :: one ( ) / n1 + T :: one ( ) / n2) ) . sqrt ( ) ;
116+ let t = ( mean1 - mean2) / std_err;
117+ ( t, n1 + n2 - T :: from ( 2.0 ) . unwrap ( ) )
118118 }
119119 TTestType :: Welch => {
120- // Welch's t-test (unequal variances)
121- let term1 = var_x / nx_f;
122- let term2 = var_y / ny_f;
120+ let term1 = var1 / n1;
121+ let term2 = var2 / n2;
123122 let combined_var = term1 + term2;
124123
125124 if combined_var <= T :: zero ( ) {
126125 return TestResult :: new ( <T as num_traits:: Float >:: infinity ( ) , T :: zero ( ) ) ;
127126 }
128127
129128 let std_err = combined_var. sqrt ( ) ;
130- let t = ( mean_x - mean_y ) / std_err;
129+ let t = ( mean1 - mean2 ) / std_err;
131130
132- // Welch-Satterthwaite equation for degrees of freedom
131+ // Welch-Satterthwaite equation
133132 let df = combined_var * combined_var
134- / ( term1 * term1 / ( nx_f - T :: one ( ) ) + term2 * term2 / ( ny_f - T :: one ( ) ) ) ;
133+ / ( term1 * term1 / ( n1 - T :: one ( ) ) + term2 * term2 / ( n2 - T :: one ( ) ) ) ;
135134 ( t, df)
136135 }
137136 } ;
138137
139- // Handle edge cases
138+ let p_value = fast_t_test_p_value ( t_stat, df) ;
139+
140+ TestResult :: new ( t_stat, p_value)
141+ }
142+
143+ fn fast_t_test_p_value < T > ( t_stat : T , df : T ) -> T
144+ where
145+ T : FloatOps ,
146+ {
140147 if !num_traits:: Float :: is_finite ( t_stat) {
141- return TestResult :: new (
142- t_stat,
143- if num_traits:: Float :: is_infinite ( t_stat) {
144- T :: zero ( )
145- } else {
146- T :: one ( )
147- } ,
148- ) ;
148+ return if num_traits:: Float :: is_infinite ( t_stat) {
149+ T :: zero ( )
150+ } else {
151+ T :: one ( )
152+ } ;
149153 }
150154
151155 if df <= T :: zero ( ) || !num_traits:: Float :: is_finite ( df) {
152- return TestResult :: new ( t_stat , T :: one ( ) ) ;
156+ return T :: one ( ) ;
153157 }
154158
155- // Calculate p-value using t-distribution
156- let p_value = calculate_p_value ( t_stat, df, alternative) ;
159+ if df > T :: from ( 30.0 ) . unwrap ( ) {
160+ let abs_t = num_traits:: Float :: abs ( t_stat) ;
161+ return T :: from ( 2.0 ) . unwrap ( ) * normal_cdf_complement ( abs_t) ;
162+ }
163+
164+ let t_f64 = t_stat. to_f64 ( ) . unwrap ( ) ;
165+ let df_f64 = df. to_f64 ( ) . unwrap ( ) ;
157166
158- TestResult :: new ( t_stat, num_traits:: Float :: clamp ( p_value, T :: zero ( ) , T :: one ( ) ) )
167+ match StudentsT :: new ( 0.0 , 1.0 , df_f64) {
168+ Ok ( t_dist) => {
169+ T :: from ( 2.0 ) . unwrap ( ) * ( T :: one ( ) - T :: from ( t_dist. cdf ( t_f64. abs ( ) ) ) . unwrap ( ) )
170+ }
171+ Err ( _) => T :: one ( ) ,
172+ }
159173}
160174
161175fn calculate_p_value < T > ( t_stat : T , df : T , alternative : Alternative ) -> T
@@ -168,18 +182,9 @@ where
168182 match StudentsT :: new ( 0.0 , 1.0 , df_f64) {
169183 Ok ( t_dist) => {
170184 let p = match alternative {
171- Alternative :: TwoSided => {
172- // Two-tailed test
173- 2.0 * ( 1.0 - t_dist. cdf ( t_f64. abs ( ) ) )
174- }
175- Alternative :: Less => {
176- // Left-tailed test: P(T <= t)
177- t_dist. cdf ( t_f64)
178- }
179- Alternative :: Greater => {
180- // Right-tailed test: P(T >= t)
181- 1.0 - t_dist. cdf ( t_f64)
182- }
185+ Alternative :: TwoSided => 2.0 * ( 1.0 - t_dist. cdf ( t_f64. abs ( ) ) ) ,
186+ Alternative :: Less => t_dist. cdf ( t_f64) ,
187+ Alternative :: Greater => 1.0 - t_dist. cdf ( t_f64) ,
183188 } ;
184189 T :: from ( p) . unwrap ( )
185190 }
@@ -200,7 +205,6 @@ where
200205 let df_f64 = df. to_f64 ( ) . unwrap ( ) ;
201206 let p_f64 = p. to_f64 ( ) . unwrap ( ) ;
202207
203- // Create a Student's t distribution with the specified degrees of freedom
204208 match StudentsT :: new ( 0.0 , 1.0 , df_f64) {
205209 Ok ( dist) => T :: from ( dist. inverse_cdf ( p_f64) ) . unwrap ( ) ,
206210 Err ( _) => panic ! ( "Failed to create StudentsT distribution" ) ,
0 commit comments