1
1
from collections .abc import Iterable
2
- from itertools import chain
3
2
from pathlib import Path
4
- from typing import TYPE_CHECKING , Any , Self
3
+ from typing import TYPE_CHECKING , Any , Self , TypeVar
5
4
6
5
import matplotlib .pyplot as plt
7
6
import numpy as np
15
14
from matplotlib .patches import Patch
16
15
from matplotlib_scalebar .scalebar import ScaleBar
17
16
from mpl_toolkits import axes_grid1
18
- from numba import njit
19
17
from numpy .typing import NDArray
20
18
from scipy .sparse import coo_array , csc_array , csr_array
21
19
from skimage .feature import peak_local_max
33
31
from ._utils import (
34
32
SCALEBAR_PARAMS ,
35
33
CosineCelltypeCallable ,
36
- _apply_color ,
37
34
_filter_blobs ,
38
35
_get_cell_dtype ,
39
36
_localmax_anndata ,
@@ -358,7 +355,6 @@ def load_local_maxima(
358
355
}
359
356
360
357
if self .total_mRNA_KDE is not None :
361
-
362
358
sdata_dict ["total_mRNA" ] = Image2DModel .parse (
363
359
np .atleast_3d (self .total_mRNA_KDE ).T , dims = ("c" , "y" , "x" )
364
360
)
@@ -416,7 +412,6 @@ def load_local_maxima(
416
412
return adata
417
413
418
414
def _load_KDE_maxima (self , genes : list [str ]) -> csc_array | csr_array :
419
-
420
415
assert self .local_maxima is not None
421
416
if self .kernel is None :
422
417
raise ValueError ("`kernel` must be set before running KDE" )
@@ -461,17 +456,13 @@ def filter_background(
461
456
If cell type-specific thresholds do not include all cell types or if
462
457
using cell type-specific thresholds before cell type assignment.
463
458
"""
459
+ T = TypeVar ("T" )
464
460
465
- @njit
466
461
def _map_celltype_to_value (
467
- ct_map : NDArray [np .integer ], thresholds : tuple [ float , ... ]
462
+ ct_map : NDArray [np .integer ], thresholds : dict [ T , float ], classes : list [ T ]
468
463
) -> NDArray [np .floating ]:
469
- values = np .zeros (shape = ct_map .shape , dtype = float )
470
- for i in range (ct_map .shape [0 ]):
471
- for j in range (ct_map .shape [1 ]):
472
- if ct_map [i , j ] >= 0 :
473
- values [i , j ] = thresholds [ct_map [i , j ]]
474
- return values
464
+ ordered_thresholds = np .array ([0 ] + [thresholds [ct ] for ct in classes ])
465
+ return np .take (ordered_thresholds , ct_map + 1 )
475
466
476
467
if self .total_mRNA_KDE is None :
477
468
raise ValueError (
@@ -485,8 +476,9 @@ def _map_celltype_to_value(
485
476
)
486
477
elif not all ([ct in min_norm .keys () for ct in self .celltypes ]):
487
478
raise ValueError ("'min_norm' does not contain all celltypes." )
488
- idx2threshold = tuple (min_norm [ct ] for ct in self .celltypes )
489
- threshold = _map_celltype_to_value (self .celltype_map , idx2threshold )
479
+ threshold = _map_celltype_to_value (
480
+ self .celltype_map , min_norm , self .celltypes
481
+ )
490
482
background = self .total_mRNA_KDE < threshold
491
483
else :
492
484
background = self .total_mRNA_KDE < min_norm
@@ -503,8 +495,9 @@ def _map_celltype_to_value(
503
495
)
504
496
elif not all ([ct in min_cosine .keys () for ct in self .celltypes ]):
505
497
raise ValueError ("'min_cosine' does not contain all celltypes." )
506
- idx2threshold = tuple (min_cosine [ct ] for ct in self .celltypes )
507
- threshold = _map_celltype_to_value (self .celltype_map , idx2threshold )
498
+ threshold = _map_celltype_to_value (
499
+ self .celltype_map , min_cosine , self .celltypes
500
+ )
508
501
background |= self .cosine_similarity <= threshold
509
502
else :
510
503
background |= self .cosine_similarity <= min_cosine
@@ -521,8 +514,9 @@ def _map_celltype_to_value(
521
514
)
522
515
elif not all ([ct in min_assignment .keys () for ct in self .celltypes ]):
523
516
raise ValueError ("'min_assignment' does not contain all celltypes." )
524
- idx2threshold = tuple (min_assignment [ct ] for ct in self .celltypes )
525
- threshold = _map_celltype_to_value (self .celltype_map , idx2threshold )
517
+ threshold = _map_celltype_to_value (
518
+ self .celltype_map , min_assignment , self .celltypes
519
+ )
526
520
background |= self .assignment_score <= threshold
527
521
else :
528
522
background |= self .assignment_score <= min_assignment
@@ -1006,11 +1000,11 @@ def plot_celltype_map(
1006
1000
color_map = [to_rgb (c ) if isinstance (c , str ) else c for c in cmap ]
1007
1001
1008
1002
# convert to uint8 to reduce memory of final image
1009
- color_map_int = tuple (
1010
- (np .array (c ) * 255 ).round ().astype (np .uint8 )
1011
- for c in chain ([to_rgb (background )], color_map )
1003
+ color_map_int = (
1004
+ (np .array ([to_rgb (background )] + color_map ) * 255 ).round ().astype (np .uint8 )
1012
1005
)
1013
- img = _apply_color (celltype_map .T , color_map_int )
1006
+
1007
+ img = np .take (color_map_int , celltype_map .T , axis = 0 )
1014
1008
1015
1009
if return_img :
1016
1010
return img
0 commit comments