Skip to content

Commit 09b747d

Browse files
Some additional refactorings, removed unused vars in ovrlpy/_utils.py, resolved double localmax search.
1 parent 533325a commit 09b747d

File tree

2 files changed

+25
-51
lines changed

2 files changed

+25
-51
lines changed

ovrlpy/_ovrlp.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
_compute_divergence_patched,
2121
_create_histogram,
2222
_create_knn_graph,
23-
_determine_localmax,
23+
_determine_localmax_and_sample,
2424
_fill_color_axes,
2525
_get_knn_expression,
2626
_get_spatial_subsample_mask,
@@ -274,7 +274,7 @@ def get_pseudocell_locations(
274274
df, genes=genes, min_expression=min_expression, KDE_bandwidth=KDE_bandwidth
275275
)
276276

277-
pseudocell_locations_x, pseudocells_y, _ = _determine_localmax(
277+
pseudocell_locations_x, pseudocells_y, _ = _determine_localmax_and_sample(
278278
hist, min_distance=min_distance, min_expression=min_expression
279279
)
280280

@@ -443,7 +443,7 @@ def detect_doublets(
443443
if integrity_sigma is not None:
444444
integrity_map = gaussian_filter(integrity_map, integrity_sigma)
445445

446-
dist_x, dist_y, dist_t = _determine_localmax(
446+
dist_x, dist_y, dist_t = _determine_localmax_and_sample(
447447
(1 - integrity_map) * (signal_map > minimum_signal_strength),
448448
min_distance=min_distance,
449449
min_expression=integrity_threshold,
@@ -787,7 +787,6 @@ def transform(self, coordinate_df: pd.DataFrame):
787787
self.pca_2d,
788788
embedder_2d=self.embedder_2d,
789789
embedder_3d=self.embedder_3d,
790-
colors_min_max=self.colors_min_max,
791790
)
792791
subsample_embedding_color, _ = _fill_color_axes(
793792
subsample_embedding_color, self.pca_3d
@@ -825,8 +824,8 @@ def pseudocell_df(self) -> pd.DataFrame:
825824

826825
def plot_region_of_interest(
827826
self,
828-
subsample,
829-
subsample_embedding_color,
827+
subsample: pd.DataFrame,
828+
subsample_embedding_color: np.ndarray,
830829
x: float = None,
831830
y: float = None,
832831
window_size: int = None,
@@ -839,7 +838,7 @@ def plot_region_of_interest(
839838
----------
840839
subsample : pandas.DataFrame
841840
A dataframe of molecule coordinates and gene assignments.
842-
subsample_embedding_color : Optional[pandas.DataFrame]
841+
subsample_embedding_color : numpy.ndarray
843842
A list of rgb values for each molecule.
844843
x : float
845844
Center x-coordinate for the region-of-interest.

ovrlpy/_utils.py

Lines changed: 19 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
from concurrent.futures import ThreadPoolExecutor, as_completed
22

3+
import matplotlib.patheffects as PathEffects
4+
35
# create circular kernel:
46
# draw outlines around artist:
5-
import matplotlib.patheffects as PathEffects
67
import matplotlib.pyplot as plt
78
import numpy as np
89
import pandas as pd
910
import tqdm
10-
from scipy.ndimage import gaussian_filter, maximum_filter
11+
from scipy.ndimage import gaussian_filter
1112
from sklearn.decomposition import PCA
1213
from sklearn.neighbors import NearestNeighbors
1314

14-
from ._ssam2 import kde_2d
15+
from ._ssam2 import find_local_maxima, kde_2d
1516

1617

17-
def _draw_outline(ax, artist, lw=2, color="black"):
18+
def _draw_outline(artist, lw=2, color="black"):
19+
"Draws outlines around the (text) artists for better legibility."
1820
_ = artist.set_path_effects(
1921
[PathEffects.withStroke(linewidth=lw, foreground=color), PathEffects.Normal()]
2022
)
@@ -43,33 +45,12 @@ def _plot_scalebar(
4345
)
4446

4547
if edge_color is not None:
46-
_draw_outline(ax, plot_artist[0], lw=5, color=edge_color)
47-
_draw_outline(ax, text_artist, lw=5, color=edge_color)
48+
_draw_outline(plot_artist[0], lw=5, color=edge_color)
49+
_draw_outline(text_artist, lw=5, color=edge_color)
4850

4951
return plot_artist, text_artist
5052

5153

52-
def _create_circular_kernel(r):
53-
"""
54-
Creates a circular kernel of radius r.
55-
56-
Parameters
57-
----------
58-
r : int
59-
The radius of the kernel.
60-
61-
Returns
62-
-------
63-
kernel : np.array
64-
A 2d array of the circular kernel.
65-
66-
"""
67-
68-
span = np.linspace(-1, 1, r * 2)
69-
X, Y = np.meshgrid(span, span)
70-
return (X**2 + Y**2) ** 0.5 <= 1
71-
72-
7354
def _get_kl_divergence(p, q):
7455
# mask = (p!=0) * (q!=0)
7556
output = np.zeros(p.shape)
@@ -78,7 +59,7 @@ def _get_kl_divergence(p, q):
7859
return output
7960

8061

81-
def _determine_localmax(distribution, min_distance=3, min_expression=5):
62+
def _determine_localmax_and_sample(distribution, min_distance=3, min_expression=5):
8263
"""
8364
Returns a list of local maxima in a kde of the data frame.
8465
@@ -99,12 +80,8 @@ def _determine_localmax(distribution, min_distance=3, min_expression=5):
9980
A list of y coordinates of local maxima.
10081
10182
"""
102-
localmax_kernel = _create_circular_kernel(min_distance)
103-
localmax_projection = distribution == maximum_filter(
104-
distribution, footprint=localmax_kernel
105-
)
10683

107-
rois_x, rois_y = np.where((distribution > min_expression) & localmax_projection)
84+
rois_x, rois_y = find_local_maxima(distribution, min_distance, min_expression)
10885

10986
return rois_x, rois_y, distribution[rois_x, rois_y]
11087

@@ -148,15 +125,15 @@ def _min_to_max(arr, arr_min=None, arr_max=None):
148125

149126
# define a function that fits expression data into the umap embeddings:
150127
def _transform_embeddings(
151-
expression, pca, embedder_2d, embedder_3d, colors_min_max=[None, None]
128+
expression,
129+
pca,
130+
embedder_2d,
131+
embedder_3d,
152132
):
153133
factors = pca.transform(expression)
154134

155135
embedding = embedder_2d.transform(factors)
156136
embedding_color = embedder_3d.transform(factors)
157-
# embedding_color = embedder_3d.transform(embedding)
158-
159-
# embedding_color = _min_to_max(embedding_color,colors_min_max[0],colors_min_max[1])
160137

161138
return embedding, embedding_color
162139

@@ -192,12 +169,12 @@ def _plot_embeddings(
192169
)
193170

194171
text_artists = []
195-
for i in range(len(celltypes)):
172+
for i, celltype in enumerate(celltypes):
196173
if not np.isnan(celltype_centers[i, 0]):
197174
t = ax.text(
198175
np.nan_to_num((celltype_centers[i, 0])),
199176
np.nan_to_num(celltype_centers[i, 1]),
200-
celltypes[i],
177+
celltype,
201178
color="k",
202179
fontsize=12,
203180
)
@@ -364,6 +341,9 @@ def _compute_divergence_embedded(
364341
metric="cosine_similarity",
365342
pca_divergence=0.8,
366343
):
344+
"""This is a legacy function, replaced by _compute_divergence_patched. It contains other similarity measures than cosine similarity.
345+
To be integrated into the patch-based divergence computation later.
346+
"""
367347
signal = _create_histogram(
368348
df,
369349
genes,
@@ -381,9 +361,6 @@ def _compute_divergence_embedded(
381361
df_top = df[df.z_delim < df.z]
382362
df_bot = df[df.z_delim > df.z]
383363

384-
# dr_bottom = np.zeros((df_bottom.shape[0],df_bottom.shape[1], pca.components_.shape[0]))
385-
# dr_top = np.zeros((df_bottom.shape[0],df_bottom.shape[1], pca.components_.shape[0]))
386-
387364
hists_top = np.zeros((mask.sum(), pca.components_.shape[0]))
388365
hists_bot = np.zeros((mask.sum(), pca.components_.shape[0]))
389366

@@ -481,8 +458,6 @@ def pearson_cross_correlation(a, b):
481458

482459

483460
def _compute_embedding_vectors(subset_df, signal_mask, factor):
484-
# for i,g in tqdm.tqdm(enumerate(genes),total=len(genes)):
485-
486461
if len(subset_df) < 2:
487462
return None, None
488463

0 commit comments

Comments
 (0)