1
1
from collections .abc import Iterable
2
- from itertools import chain
3
- from typing import TYPE_CHECKING , Any
2
+ from typing import TYPE_CHECKING , Any , Self , TypeVar
4
3
5
4
import matplotlib .pyplot as plt
6
5
import numpy as np
14
13
from matplotlib .patches import Patch
15
14
from matplotlib_scalebar .scalebar import ScaleBar
16
15
from mpl_toolkits import axes_grid1
17
- from numba import njit
18
16
from numpy .typing import NDArray
19
17
from scipy .sparse import coo_array , csc_array , csr_array
20
18
from skimage .feature import peak_local_max
21
- from typing_extensions import Self
22
19
23
20
from .._typealias import _Cmap , _Csx , _CsxArray , _Local_Max , _RangeTuple2D
24
21
from .._utils import _raise_module_load_error , _validate_n_threads , validate_threads
33
30
from ._utils import (
34
31
SCALEBAR_PARAMS ,
35
32
CosineCelltypeCallable ,
36
- _apply_color ,
37
33
_filter_blobs ,
38
34
_get_cell_dtype ,
39
35
_localmax_anndata ,
@@ -358,7 +354,6 @@ def load_local_maxima(
358
354
}
359
355
360
356
if self .total_mRNA_KDE is not None :
361
-
362
357
sdata_dict ["total_mRNA" ] = Image2DModel .parse (
363
358
np .atleast_3d (self .total_mRNA_KDE ).T , dims = ("c" , "y" , "x" )
364
359
)
@@ -416,7 +411,6 @@ def load_local_maxima(
416
411
return adata
417
412
418
413
def _load_KDE_maxima (self , genes : list [str ]) -> csc_array | csr_array :
419
-
420
414
assert self .local_maxima is not None
421
415
if self .kernel is None :
422
416
raise ValueError ("`kernel` must be set before running KDE" )
@@ -461,17 +455,13 @@ def filter_background(
461
455
If cell type-specific thresholds do not include all cell types or if
462
456
using cell type-specific thresholds before cell type assignment.
463
457
"""
458
+ T = TypeVar ("T" )
464
459
465
- @njit
466
460
def _map_celltype_to_value (
467
- ct_map : NDArray [np .integer ], thresholds : tuple [ float , ... ]
461
+ ct_map : NDArray [np .integer ], thresholds : dict [ T , float ], classes : list [ T ]
468
462
) -> NDArray [np .floating ]:
469
- values = np .zeros (shape = ct_map .shape , dtype = float )
470
- for i in range (ct_map .shape [0 ]):
471
- for j in range (ct_map .shape [1 ]):
472
- if ct_map [i , j ] >= 0 :
473
- values [i , j ] = thresholds [ct_map [i , j ]]
474
- return values
463
+ ordered_thresholds = np .array ([0 ] + [thresholds [ct ] for ct in classes ])
464
+ return np .take (ordered_thresholds , ct_map + 1 )
475
465
476
466
if self .total_mRNA_KDE is None :
477
467
raise ValueError (
@@ -485,8 +475,9 @@ def _map_celltype_to_value(
485
475
)
486
476
elif not all ([ct in min_norm .keys () for ct in self .celltypes ]):
487
477
raise ValueError ("'min_norm' does not contain all celltypes." )
488
- idx2threshold = tuple (min_norm [ct ] for ct in self .celltypes )
489
- threshold = _map_celltype_to_value (self .celltype_map , idx2threshold )
478
+ threshold = _map_celltype_to_value (
479
+ self .celltype_map , min_norm , self .celltypes
480
+ )
490
481
background = self .total_mRNA_KDE < threshold
491
482
else :
492
483
background = self .total_mRNA_KDE < min_norm
@@ -503,8 +494,9 @@ def _map_celltype_to_value(
503
494
)
504
495
elif not all ([ct in min_cosine .keys () for ct in self .celltypes ]):
505
496
raise ValueError ("'min_cosine' does not contain all celltypes." )
506
- idx2threshold = tuple (min_cosine [ct ] for ct in self .celltypes )
507
- threshold = _map_celltype_to_value (self .celltype_map , idx2threshold )
497
+ threshold = _map_celltype_to_value (
498
+ self .celltype_map , min_cosine , self .celltypes
499
+ )
508
500
background |= self .cosine_similarity <= threshold
509
501
else :
510
502
background |= self .cosine_similarity <= min_cosine
@@ -521,8 +513,9 @@ def _map_celltype_to_value(
521
513
)
522
514
elif not all ([ct in min_assignment .keys () for ct in self .celltypes ]):
523
515
raise ValueError ("'min_assignment' does not contain all celltypes." )
524
- idx2threshold = tuple (min_assignment [ct ] for ct in self .celltypes )
525
- threshold = _map_celltype_to_value (self .celltype_map , idx2threshold )
516
+ threshold = _map_celltype_to_value (
517
+ self .celltype_map , min_assignment , self .celltypes
518
+ )
526
519
background |= self .assignment_score <= threshold
527
520
else :
528
521
background |= self .assignment_score <= min_assignment
@@ -985,11 +978,11 @@ def plot_celltype_map(
985
978
color_map = [to_rgb (c ) if isinstance (c , str ) else c for c in cmap ]
986
979
987
980
# convert to uint8 to reduce memory of final image
988
- color_map_int = tuple (
989
- (np .array (c ) * 255 ).round ().astype (np .uint8 )
990
- for c in chain ([to_rgb (background )], color_map )
981
+ color_map_int = (
982
+ (np .array ([to_rgb (background )] + color_map ) * 255 ).round ().astype (np .uint8 )
991
983
)
992
- img = _apply_color (celltype_map .T , color_map_int )
984
+
985
+ img = np .take (color_map_int , celltype_map .T , axis = 0 )
993
986
994
987
if return_img :
995
988
return img
0 commit comments