1
+ from importlib .resources import path
1
2
from textwrap import fill
2
3
from threading import local
3
4
4
5
import numpy as np
5
6
# from numpy.ma.extras import _covhelper
6
- # from numpy.ma.core import dot
7
+ # from numpy.ma.core import dot
7
8
8
9
from requests import patch
9
10
from scipy .ndimage import gaussian_filter , maximum_filter
10
11
from plankton .pixelmaps import PixelMap
11
12
12
-
13
+ from sklearn .neighbors import NearestNeighbors
14
+ from sklearn .decomposition import FastICA
15
+ from sklearn .cluster import KMeans
13
16
14
17
import matplotlib .pyplot as plt
15
18
19
+ from tqdm import tqdm_notebook
16
20
17
- def determine_gains (stats1 ,stats2 ):
21
+ def determine_gains (stats1 , stats2 ):
18
22
19
23
norm_counts1 = stats1 .counts / stats1 .counts .sum ()
20
24
norm_counts2 = stats2 .counts / stats2 .counts .sum ()
21
-
25
+
22
26
return norm_counts1 / norm_counts2
23
27
24
28
@@ -54,6 +58,7 @@ def hbar_compare(stat1, stat2, labels=None, text_display_threshold=0.02, c=None)
54
58
if labels is not None :
55
59
plt .xticks ((0 , 1 ), labels )
56
60
61
+
57
62
def sorted_bar_compare (stat1 , stat2 , kwargs1 = {}, kwargs2 = {}):
58
63
categories_1 = (stat1 .index )
59
64
counts_1 = (np .array (stat1 .counts ).flatten ())
@@ -169,6 +174,7 @@ def sorted_bar_compare(stat1, stat2, kwargs1={}, kwargs2={}):
169
174
# ax3.xaxis.set_label_position('top')
170
175
# ax3.set_ylabel('log(count) spatial')
171
176
177
+
172
178
def fill_celltypemaps (ct_map , fill_blobs = True , min_blob_area = 0 , filter_params = {}, output_mask = None ):
173
179
"""
174
180
Post-filter cell type maps created by `map_celltypes`.
@@ -191,25 +197,25 @@ def fill_celltypemaps(ct_map, fill_blobs=True, min_blob_area=0, filter_params={}
191
197
"""
192
198
193
199
from skimage import measure
194
-
195
-
200
+
196
201
filtered_ctmaps = np .zeros_like (ct_map ) - 1
197
202
198
203
for cidx in np .unique (ct_map ):
199
- mask = ct_map == cidx
204
+ mask = ct_map == cidx
200
205
if min_blob_area > 0 or fill_blobs :
201
206
blob_labels = measure .label (mask , background = 0 )
202
207
for bp in measure .regionprops (blob_labels ):
203
208
if min_blob_area > 0 and bp .filled_area < min_blob_area :
204
209
for c in bp .coords :
205
- mask [c [0 ], c [1 ],] = 0
206
-
210
+ mask [c [0 ], c [1 ], ] = 0
211
+
207
212
continue
208
213
if fill_blobs and bp .area != bp .filled_area :
209
214
minx , miny , maxx , maxy , = bp .bbox
210
- mask [minx :maxx , miny :maxy ,] |= bp .filled_image
215
+ mask [minx :maxx , miny :maxy , ] |= bp .filled_image
211
216
212
- filtered_ctmaps [np .logical_and (mask == 1 , np .logical_or (ct_map == - 1 , ct_map == cidx ))] = cidx
217
+ filtered_ctmaps [np .logical_and (mask == 1 , np .logical_or (
218
+ ct_map == - 1 , ct_map == cidx ))] = cidx
213
219
214
220
return filtered_ctmaps
215
221
@@ -240,87 +246,128 @@ def get_histograms(sdata, mins=None, maxs=None, category=None, resolution=5):
240
246
241
247
return histograms ,
242
248
243
- def crosscorr (x ,y ):
244
- x -= x .mean (1 )[:,None ]
245
- y -= y .mean (1 )[:,None ]
246
249
247
- c = (np .dot (x , y .T )/ x .shape [1 ] ).squeeze ()
248
- return c / x .std (1 )[:,None ]/ y .std (1 )
250
+ def crosscorr (x , y ):
251
+ x -= x .mean (1 )[:, None ]
252
+ y -= y .mean (1 )[:, None ]
253
+ c = (np .dot (x , y .T )/ x .shape [1 ]).squeeze ()
254
+
255
+ return np .nan_to_num (np .nan_to_num (c / np .array (x .std (1 ))[:, None ])/ np .array (y .std (1 ))[None , :])
249
256
250
257
251
258
def ssam (sdata , signatures = None , adata_obs_label = 'celltype' , kernel_bandwidth = 2.5 , output_um_p_px = 5 ,
252
- patch_length = 1000 , threshold_exp = 0.1 , threshold_cor = 0.1 , background_value = - 1 ,
253
- fill_blobs = True ,min_blob_area = 10 ):
259
+ patch_length = 1000 , threshold_exp = 0.1 , threshold_cor = 0.1 , background_value = - 1 ,
260
+ fill_blobs = True , min_blob_area = 10 ):
254
261
255
262
if (sdata .scanpy is not None ) and (signatures is None ):
256
263
signatures = sdata .scanpy .generate_signatures (adata_obs_label )
257
264
258
- kernel_bandwidth_px = kernel_bandwidth / output_um_p_px
265
+ kernel_bandwidth_px = kernel_bandwidth / output_um_p_px
259
266
260
267
out_shape = (np .ceil (sdata .x .max ()/ output_um_p_px + kernel_bandwidth_px * 3 ).astype (int ),
261
268
np .ceil (sdata .y .max ()/ output_um_p_px + kernel_bandwidth_px * 3 ).astype (int ))
262
269
263
270
ct_map = np .zeros ((out_shape ), dtype = int )+ background_value
264
271
vf_norm = np .zeros_like (ct_map )
265
272
266
- range_x = np .floor (sdata .x .min ()).astype (int ),np .ceil (sdata .x .max ()).astype (int )
267
- patch_delimiters_x = list (range (range_x [0 ],range_x [1 ],patch_length ))+ [range_x [1 ]]
273
+ range_x = np .floor (sdata .x .min ()).astype (
274
+ int ), np .ceil (sdata .x .max ()).astype (int )
275
+ patch_delimiters_x = list (
276
+ range (range_x [0 ], range_x [1 ], patch_length ))+ [range_x [1 ]]
277
+
278
+ range_y = np .floor (sdata .y .min ()).astype (
279
+ int ), np .ceil (sdata .y .max ()).astype (int )
280
+ patch_delimiters_y = list (
281
+ range (range_y [0 ], range_y [1 ], patch_length ))+ [range_y [1 ]]
282
+
283
+ print (list (patch_delimiters_x ), list (patch_delimiters_y ))
284
+
285
+ with tqdm_notebook (total = (len (patch_delimiters_x )- 1 )* (len (patch_delimiters_y )- 1 )) as pbar :
286
+ for i , x in enumerate (patch_delimiters_x [:- 1 ]):
287
+ for j , y in enumerate (patch_delimiters_y [:- 1 ]):
288
+ sdata_patch = sdata .raw ().spatial [x :patch_delimiters_x [i + 1 ],
289
+ y :patch_delimiters_y [j + 1 ]]
290
+
291
+ if not len (sdata_patch ):
292
+ break
293
+
294
+ mins = sdata_patch .coordinates .min (0 ).astype (int )
295
+ maxs = sdata_patch .coordinates .max (
296
+ 0 ).astype (int )+ kernel_bandwidth * 3
297
+ hists = get_histograms (
298
+ sdata_patch , mins = mins , maxs = maxs , resolution = output_um_p_px )
299
+ hists = np .concatenate (
300
+ [gaussian_filter (h , kernel_bandwidth_px ) for h in hists ])
301
+
302
+ # print(hists.shape)
303
+
304
+ norm = np .sum (hists , axis = 0 )
305
+
306
+ vf_norm [(x + mins [0 ])// output_um_p_px :(x + mins [0 ])// output_um_p_px + norm .shape [0 ],
307
+ (y + mins [1 ])// output_um_p_px :(y + mins [1 ])// output_um_p_px + norm .shape [1 ], ] = norm
308
+
309
+ mask = norm > threshold_exp
310
+
311
+ exps = np .zeros ((len (sdata .genes ), mask .sum ()))
312
+
313
+ # print(exps.shape,signatures.shape)
314
+ exps [sdata .stats .loc [sdata_patch .genes ].gene_ids .values ,
315
+ :] = hists [:, mask ]
316
+
317
+ local_ct_map = mask .astype (int )- 1
318
+
319
+ corrs = crosscorr (exps .T , signatures )
320
+ # print(corrs.min())
321
+ corrs_winners = corrs .argmax (1 )
322
+ corrs_winners [corrs .max (1 ) < threshold_cor ] = background_value
323
+ local_ct_map [mask ] = corrs_winners
268
324
269
- range_y = np .floor (sdata .y .min ()).astype (int ),np .ceil (sdata .y .max ()).astype (int )
270
- patch_delimiters_y = list (range (range_y [0 ],range_y [1 ],patch_length ))+ [range_y [1 ]]
325
+ # idcs=sdata_patch.index
271
326
272
- print (list (patch_delimiters_x ),list (patch_delimiters_y ))
327
+ x_ = int (max (0 , (x + mins [0 ])- kernel_bandwidth_px * 3 ))// output_um_p_px
328
+ y_ = int (max (0 , (y + mins [1 ])- kernel_bandwidth_px * 3 ))// output_um_p_px
273
329
274
- for i ,x in enumerate (patch_delimiters_x [:- 1 ]):
275
- for j ,y in enumerate (patch_delimiters_y [:- 1 ]):
276
- sdata_patch = sdata .raw ().spatial [x :patch_delimiters_x [i + 1 ],
277
- y :patch_delimiters_y [j + 1 ]]
330
+ ct_map [x_ :x_ + local_ct_map .shape [0 ],
331
+ y_ :y_ + local_ct_map .shape [1 ]] = local_ct_map
278
332
279
- if not len (sdata_patch ): break
280
-
281
- mins = sdata_patch .coordinates .min (0 ).astype (int )
282
- maxs = sdata_patch .coordinates .max (0 ).astype (int )+ kernel_bandwidth * 3
283
- hists = get_histograms (sdata_patch ,mins = mins ,maxs = maxs , resolution = output_um_p_px )
284
- hists = np .concatenate ([gaussian_filter (h , kernel_bandwidth_px ) for h in hists ])
333
+ pbar .update (1 )
285
334
286
- print (hists .shape )
335
+ ct_map = fill_celltypemaps (
336
+ ct_map , min_blob_area = min_blob_area , fill_blobs = fill_blobs )
337
+ return PixelMap (ct_map .T , px_p_um = 1 / output_um_p_px )
287
338
288
- norm = np .sum (hists , axis = 0 )
289
339
290
- vf_norm [( x + mins [ 0 ]) // output_um_p_px :( x + mins [ 0 ]) // output_um_p_px + norm . shape [ 0 ],
291
- ( y + mins [ 1 ]) // output_um_p_px :( y + mins [ 1 ]) // output_um_p_px + norm . shape [ 1 ],] = norm
340
+ def localmax_sampling ( sdata , n_clusters = 10 , min_distance = 3 , bandwidth = 4 ):
341
+ n_bins = np . array ( sdata . spatial . shape )
292
342
293
- mask = norm > threshold_exp
343
+ vf = (gaussian_filter (np .histogram2d (
344
+ * sdata .coordinates .T , bins = n_bins )[0 ], 2 ))
345
+ localmaxs = np .where ((maximum_filter (vf , min_distance ) == vf ) & (vf > 0.2 ))
346
+ knn = NearestNeighbors (n_neighbors = 150 )
347
+ knn .fit (sdata .coordinates )
348
+ dists , nbrs = knn .kneighbors (np .array (localmaxs ).T )
349
+ neighbor_types = np .array (sdata .gene_ids )[nbrs ]
294
350
295
- exps = np .zeros ((len (sdata .genes ), mask . sum ( )))
351
+ counts = np .zeros ((dists . shape [ 0 ], len (sdata .genes )))
296
352
297
- print ( exps . shape , signatures . shape )
298
- exps [ sdata . stats . loc [ sdata_patch . genes ]. gene_ids . values ,:] = hists [:, mask ]
353
+ bandwidth = bandwidth
354
+ def kernel ( x ): return np . exp ( - x ** 2 / ( 2 * bandwidth ** 2 ))
299
355
300
- # signatures -= signatures.mean(0)
301
- # signatures /= signatures.std(0)
356
+ for i in range (0 , dists .shape [1 ]):
357
+ counts [np .arange (dists .shape [0 ]), neighbor_types [:, i ]
358
+ ] += kernel (dists [:, i ])
302
359
303
- # exps -= exps.mean(0)
304
- # exps -= exps.std(0)
360
+ assert (all (counts .sum (1 )) > 0 )
305
361
306
- local_ct_map = mask .astype (int )- 1
307
- # corrs = np.inner(exps.T, signatures)
362
+ counts = counts / counts .sum (1 )[:,None ]
308
363
309
- corrs = crosscorr (exps .T ,signatures )
310
- print (corrs .min ())
311
- corrs_winners = corrs .argmax (1 )
312
- corrs_winners [corrs .max (1 )< threshold_cor ]= background_value
313
- local_ct_map [mask ] = corrs_winners
364
+ ica = FastICA (n_components = 30 )
365
+ facs = ica .fit_transform (counts )
314
366
315
- # idcs=sdata_patch.index
367
+ km = KMeans (
368
+ n_clusters = n_clusters ,
369
+ max_iter = 100 ,
370
+ n_init = 1 ,).fit (counts [:])
316
371
372
+ return km .cluster_centers_
317
373
318
- x_ = int (max (0 ,(x + mins [0 ])- kernel_bandwidth_px * 3 ))// output_um_p_px
319
- y_ = int (max (0 ,(y + mins [1 ])- kernel_bandwidth_px * 3 ))// output_um_p_px
320
-
321
-
322
- ct_map [x_ :x_ + local_ct_map .shape [0 ],
323
- y_ :y_ + local_ct_map .shape [1 ]]= local_ct_map
324
-
325
- ct_map = fill_celltypemaps (ct_map ,min_blob_area = min_blob_area , fill_blobs = fill_blobs )
326
- return PixelMap (ct_map .T ,px_p_um = 1 / output_um_p_px )
0 commit comments