1
- use crate :: cosine_zarr:: { initialize_cosine_zarrstore, write_cosine_to_zarr, ZarrChunkInfo } ;
2
1
use crate :: gridcounts:: GridCounts ;
3
2
use crate :: sparsekde:: sparse_kde_csx_;
4
3
use crate :: utils:: create_pool;
@@ -13,23 +12,20 @@ use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2};
13
12
use pyo3:: { exceptions:: PyValueError , prelude:: * } ;
14
13
use rayon:: prelude:: * ;
15
14
use sprs:: { CompressedStorage :: CSR , CsMatI , CsMatViewI , SpIndex } ;
16
- use std:: { cmp:: min, error:: Error , ops:: Range , path:: PathBuf } ;
17
- use zarrs:: array:: Element ;
15
+ use std:: { cmp:: min, error:: Error , ops:: Range } ;
18
16
19
17
macro_rules! build_cos_ct_fn {
20
18
( $name: tt, $t_cos: ty, $t_ct: ty) => {
21
19
#[ pyfunction]
22
- #[ pyo3( signature = ( counts, genes, celltypes , signatures, kernel, * , log=false , zarr_path= None , chunk_size=( 500 , 500 ) , n_threads=None ) ) ]
20
+ #[ pyo3( signature = ( counts, genes, signatures, kernel, * , log=false , chunk_size=( 500 , 500 ) , n_threads=None ) ) ]
23
21
/// calculate cosine similarity and assign celltype
24
22
pub fn $name<' py>(
25
23
py: Python <' py>,
26
24
counts: & mut GridCounts ,
27
25
genes: Vec <String >,
28
- celltypes: Vec <String >,
29
26
signatures: PyReadonlyArray2 <' py, $t_cos>,
30
27
kernel: PyReadonlyArray2 <' py, $t_cos>,
31
28
log: bool ,
32
- zarr_path: Option <PathBuf >,
33
29
chunk_size: ( usize , usize ) ,
34
30
n_threads: Option <usize >,
35
31
) -> PyResult <(
@@ -50,12 +46,10 @@ macro_rules! build_cos_ct_fn {
50
46
51
47
let cos_ct = chunk_and_calculate_cosine(
52
48
& gene_counts,
53
- celltypes,
54
49
signatures. as_array( ) ,
55
50
kernel. as_array( ) ,
56
51
counts. shape,
57
52
log,
58
- zarr_path,
59
53
chunk_size,
60
54
n_threads
61
55
) ;
@@ -77,19 +71,17 @@ build_cos_ct_fn!(cosinef32_and_celltypei16, f32, i16);
77
71
78
72
fn chunk_and_calculate_cosine < C , I , F , U > (
79
73
counts : & [ CsMatViewI < C , I > ] ,
80
- celltypes : Vec < String > ,
81
74
signatures : ArrayView2 < F > ,
82
75
kernel : ArrayView2 < F > ,
83
76
shape : ( usize , usize ) ,
84
77
log : bool ,
85
- zarr_path : Option < PathBuf > ,
86
78
chunk_size : ( usize , usize ) ,
87
79
n_threads : Option < usize > ,
88
- ) -> Result < ( Array2 < F > , Array2 < F > , Array2 < U > ) , Box < dyn Error + Send + Sync > >
80
+ ) -> Result < ( Array2 < F > , Array2 < F > , Array2 < U > ) , Box < dyn Error > >
89
81
where
90
82
C : NumCast + Copy + Sync + Send + Default ,
91
83
I : SpIndex + Signed + Sync + Send ,
92
- F : NdFloat + Element ,
84
+ F : NdFloat ,
93
85
U : PrimInt + Signed + Sync + Send ,
94
86
Slice : From < Range < I > > ,
95
87
{
@@ -123,16 +115,7 @@ where
123
115
}
124
116
} ) ;
125
117
126
- // init zarr store for celltypes with chunksize and all zero arrays
127
- let zarr_store = match zarr_path
128
- . map ( |path| initialize_cosine_zarrstore ( path, & celltypes, shape, chunk_size) )
129
- {
130
- Some ( Err ( e) ) => return Err ( e) ,
131
- Some ( Ok ( store) ) => Some ( store) ,
132
- None => None ,
133
- } ;
134
-
135
- let celltyping_results = pool. install ( || {
118
+ let ( ( cosine, score) , celltype) : ( ( Vec < _ > , Vec < _ > ) , Vec < _ > ) = pool. install ( || {
136
119
// generate all chunk indices
137
120
let chunk_indices: Vec < _ > = ( 0 ..m) . cartesian_product ( 0 ..n) . collect ( ) ;
138
121
@@ -142,28 +125,18 @@ where
142
125
. map ( |idx| {
143
126
let ( chunk, unpad) = get_chunk ( counts, idx, shape, chunk_size, pad) ;
144
127
145
- let zarr_info = zarr_store. clone ( ) . map ( |store| ZarrChunkInfo {
146
- store,
147
- celltypes : { celltypes. clone ( ) } ,
148
- chunk_idx : vec ! [ idx. 0 as u64 , idx. 1 as u64 ] ,
149
- } ) ;
150
-
151
128
cosine_and_celltype_ (
152
129
chunk,
153
130
signatures,
154
131
& signature_similarity_correction,
155
132
kernel,
156
133
unpad,
157
134
log,
158
- zarr_info,
159
135
)
160
136
} )
161
- . collect :: < Vec < _ > > ( )
137
+ . unzip ( )
162
138
} ) ;
163
139
164
- let ( ( cosine, score) , celltype) : ( ( Vec < _ > , Vec < _ > ) , Vec < _ > ) =
165
- itertools:: process_results ( celltyping_results, |iter| iter. unzip ( ) ) ?;
166
-
167
140
// concatenate all chunks back to original shape
168
141
Ok ( (
169
142
concat_2d ( & cosine, n) ?,
@@ -234,11 +207,10 @@ fn cosine_and_celltype_<C, I, F, U>(
234
207
kernel : ArrayView2 < F > ,
235
208
unpad : ( Range < usize > , Range < usize > ) ,
236
209
log : bool ,
237
- zarr_info : Option < ZarrChunkInfo > ,
238
- ) -> Result < ( ( Array2 < F > , Array2 < F > ) , Array2 < U > ) , Box < dyn Error + Send + Sync > >
210
+ ) -> ( ( Array2 < F > , Array2 < F > ) , Array2 < U > )
239
211
where
240
212
C : NumCast + Copy ,
241
- F : NdFloat + Element ,
213
+ F : NdFloat ,
242
214
U : PrimInt + Signed ,
243
215
I : SpIndex + Signed ,
244
216
Slice : From < Range < I > > ,
@@ -253,10 +225,10 @@ where
253
225
// fastpath if all csx are empty
254
226
None => {
255
227
let shape = ( unpad_r. end - unpad_r. start , unpad_c. end - unpad_c. start ) ;
256
- Ok ( (
228
+ (
257
229
( Array2 :: zeros ( shape) , Array2 :: zeros ( shape) ) ,
258
230
Array2 :: from_elem ( shape, -one :: < U > ( ) ) ,
259
- ) )
231
+ )
260
232
}
261
233
Some ( ( csx, weights) ) => {
262
234
let shape = csx. shape ( ) ;
@@ -290,24 +262,8 @@ where
290
262
. filter ( |( _, & w) | w != zero :: < F > ( ) )
291
263
. for_each ( |( mut cos, & w) | cos += & kde_unpadded. map ( |& x| x * w) ) ;
292
264
}
293
-
294
- kde_norm. mapv_inplace ( F :: sqrt) ;
295
-
296
- if let Some ( zarr_info) = zarr_info {
297
- write_cosine_to_zarr (
298
- zarr_info. store ,
299
- & cosine,
300
- & kde_norm,
301
- & zarr_info. celltypes ,
302
- & zarr_info. chunk_idx ,
303
- ) ?
304
- } ;
305
-
306
- Ok ( get_max_cosine_and_celltype (
307
- cosine,
308
- kde_norm,
309
- pairwise_correction,
310
- ) )
265
+ // TODO: write to zarr
266
+ get_max_cosine_and_celltype ( cosine, kde_norm, pairwise_correction)
311
267
}
312
268
}
313
269
}
@@ -335,8 +291,9 @@ where
335
291
* ct = -one :: < I > ( ) ;
336
292
* s = zero ( ) ;
337
293
} else {
338
- * cos /= norm;
339
- * s /= norm;
294
+ let norm_sqrt = norm. sqrt ( ) ;
295
+ * cos /= norm_sqrt;
296
+ * s /= norm_sqrt;
340
297
} ;
341
298
} ) ;
342
299
0 commit comments