Skip to content

Commit 92cdca5

Browse files
ssam-denovo
1 parent 15a803f commit 92cdca5

File tree

3 files changed

+2834
-599
lines changed

3 files changed

+2834
-599
lines changed

plankton/utils.py

Lines changed: 110 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
1+
from importlib.resources import path
12
from textwrap import fill
23
from threading import local
34

45
import numpy as np
56
# from numpy.ma.extras import _covhelper
6-
# from numpy.ma.core import dot
7+
# from numpy.ma.core import dot
78

89
from requests import patch
910
from scipy.ndimage import gaussian_filter, maximum_filter
1011
from plankton.pixelmaps import PixelMap
1112

12-
13+
from sklearn.neighbors import NearestNeighbors
14+
from sklearn.decomposition import FastICA
15+
from sklearn.cluster import KMeans
1316

1417
import matplotlib.pyplot as plt
1518

19+
from tqdm import tqdm_notebook
1620

17-
def determine_gains(stats1,stats2):
21+
def determine_gains(stats1, stats2):
1822

1923
norm_counts1 = stats1.counts/stats1.counts.sum()
2024
norm_counts2 = stats2.counts/stats2.counts.sum()
21-
25+
2226
return norm_counts1/norm_counts2
2327

2428

@@ -54,6 +58,7 @@ def hbar_compare(stat1, stat2, labels=None, text_display_threshold=0.02, c=None)
5458
if labels is not None:
5559
plt.xticks((0, 1), labels)
5660

61+
5762
def sorted_bar_compare(stat1, stat2, kwargs1={}, kwargs2={}):
5863
categories_1 = (stat1.index)
5964
counts_1 = (np.array(stat1.counts).flatten())
@@ -169,6 +174,7 @@ def sorted_bar_compare(stat1, stat2, kwargs1={}, kwargs2={}):
169174
# ax3.xaxis.set_label_position('top')
170175
# ax3.set_ylabel('log(count) spatial')
171176

177+
172178
def fill_celltypemaps(ct_map, fill_blobs=True, min_blob_area=0, filter_params={}, output_mask=None):
173179
"""
174180
Post-filter cell type maps created by `map_celltypes`.
@@ -191,25 +197,25 @@ def fill_celltypemaps(ct_map, fill_blobs=True, min_blob_area=0, filter_params={}
191197
"""
192198

193199
from skimage import measure
194-
195-
200+
196201
filtered_ctmaps = np.zeros_like(ct_map) - 1
197202

198203
for cidx in np.unique(ct_map):
199-
mask = ct_map==cidx
204+
mask = ct_map == cidx
200205
if min_blob_area > 0 or fill_blobs:
201206
blob_labels = measure.label(mask, background=0)
202207
for bp in measure.regionprops(blob_labels):
203208
if min_blob_area > 0 and bp.filled_area < min_blob_area:
204209
for c in bp.coords:
205-
mask[c[0], c[1],] = 0
206-
210+
mask[c[0], c[1], ] = 0
211+
207212
continue
208213
if fill_blobs and bp.area != bp.filled_area:
209214
minx, miny, maxx, maxy, = bp.bbox
210-
mask[minx:maxx, miny:maxy,] |= bp.filled_image
215+
mask[minx:maxx, miny:maxy, ] |= bp.filled_image
211216

212-
filtered_ctmaps[np.logical_and(mask == 1, np.logical_or(ct_map == -1, ct_map == cidx))] = cidx
217+
filtered_ctmaps[np.logical_and(mask == 1, np.logical_or(
218+
ct_map == -1, ct_map == cidx))] = cidx
213219

214220
return filtered_ctmaps
215221

@@ -240,87 +246,128 @@ def get_histograms(sdata, mins=None, maxs=None, category=None, resolution=5):
240246

241247
return histograms,
242248

243-
def crosscorr(x,y):
244-
x -= x.mean(1)[:,None]
245-
y -= y.mean(1)[:,None]
246249

247-
c = (np.dot(x, y.T)/x.shape[1] ).squeeze()
248-
return c/x.std(1)[:,None]/y.std(1)
250+
def crosscorr(x, y):
251+
x -= x.mean(1)[:, None]
252+
y -= y.mean(1)[:, None]
253+
c = (np.dot(x, y.T)/x.shape[1]).squeeze()
254+
255+
return np.nan_to_num(np.nan_to_num(c/np.array(x.std(1))[:, None])/np.array(y.std(1))[None, :])
249256

250257

251258
def ssam(sdata, signatures=None, adata_obs_label='celltype', kernel_bandwidth=2.5, output_um_p_px=5,
252-
patch_length=1000, threshold_exp=0.1, threshold_cor=0.1, background_value=-1,
253-
fill_blobs=True,min_blob_area=10):
259+
patch_length=1000, threshold_exp=0.1, threshold_cor=0.1, background_value=-1,
260+
fill_blobs=True, min_blob_area=10):
254261

255262
if (sdata.scanpy is not None) and (signatures is None):
256263
signatures = sdata.scanpy.generate_signatures(adata_obs_label)
257264

258-
kernel_bandwidth_px=kernel_bandwidth/output_um_p_px
265+
kernel_bandwidth_px = kernel_bandwidth/output_um_p_px
259266

260267
out_shape = (np.ceil(sdata.x.max()/output_um_p_px+kernel_bandwidth_px*3).astype(int),
261268
np.ceil(sdata.y.max()/output_um_p_px+kernel_bandwidth_px*3).astype(int))
262269

263270
ct_map = np.zeros((out_shape), dtype=int)+background_value
264271
vf_norm = np.zeros_like(ct_map)
265272

266-
range_x = np.floor(sdata.x.min()).astype(int),np.ceil(sdata.x.max()).astype(int)
267-
patch_delimiters_x = list(range(range_x[0],range_x[1],patch_length))+[range_x[1]]
273+
range_x = np.floor(sdata.x.min()).astype(
274+
int), np.ceil(sdata.x.max()).astype(int)
275+
patch_delimiters_x = list(
276+
range(range_x[0], range_x[1], patch_length))+[range_x[1]]
277+
278+
range_y = np.floor(sdata.y.min()).astype(
279+
int), np.ceil(sdata.y.max()).astype(int)
280+
patch_delimiters_y = list(
281+
range(range_y[0], range_y[1], patch_length))+[range_y[1]]
282+
283+
print(list(patch_delimiters_x), list(patch_delimiters_y))
284+
285+
with tqdm_notebook(total=(len(patch_delimiters_x)-1)*(len(patch_delimiters_y)-1)) as pbar:
286+
for i, x in enumerate(patch_delimiters_x[:-1]):
287+
for j, y in enumerate(patch_delimiters_y[:-1]):
288+
sdata_patch = sdata.raw().spatial[x:patch_delimiters_x[i+1],
289+
y:patch_delimiters_y[j+1]]
290+
291+
if not len(sdata_patch):
292+
break
293+
294+
mins = sdata_patch.coordinates.min(0).astype(int)
295+
maxs = sdata_patch.coordinates.max(
296+
0).astype(int)+kernel_bandwidth*3
297+
hists = get_histograms(
298+
sdata_patch, mins=mins, maxs=maxs, resolution=output_um_p_px)
299+
hists = np.concatenate(
300+
[gaussian_filter(h, kernel_bandwidth_px) for h in hists])
301+
302+
# print(hists.shape)
303+
304+
norm = np.sum(hists, axis=0)
305+
306+
vf_norm[(x+mins[0])//output_um_p_px:(x+mins[0])//output_um_p_px+norm.shape[0],
307+
(y+mins[1])//output_um_p_px:(y+mins[1])//output_um_p_px+norm.shape[1], ] = norm
308+
309+
mask = norm > threshold_exp
310+
311+
exps = np.zeros((len(sdata.genes), mask.sum()))
312+
313+
# print(exps.shape,signatures.shape)
314+
exps[sdata.stats.loc[sdata_patch.genes].gene_ids.values,
315+
:] = hists[:, mask]
316+
317+
local_ct_map = mask.astype(int)-1
318+
319+
corrs = crosscorr(exps.T, signatures)
320+
# print(corrs.min())
321+
corrs_winners = corrs.argmax(1)
322+
corrs_winners[corrs.max(1) < threshold_cor] = background_value
323+
local_ct_map[mask] = corrs_winners
268324

269-
range_y = np.floor(sdata.y.min()).astype(int),np.ceil(sdata.y.max()).astype(int)
270-
patch_delimiters_y = list(range(range_y[0],range_y[1],patch_length))+[range_y[1]]
325+
# idcs=sdata_patch.index
271326

272-
print(list(patch_delimiters_x),list(patch_delimiters_y))
327+
x_ = int(max(0, (x+mins[0])-kernel_bandwidth_px*3))//output_um_p_px
328+
y_ = int(max(0, (y+mins[1])-kernel_bandwidth_px*3))//output_um_p_px
273329

274-
for i,x in enumerate(patch_delimiters_x[:-1]):
275-
for j,y in enumerate(patch_delimiters_y[:-1]):
276-
sdata_patch = sdata.raw().spatial[x:patch_delimiters_x[i+1],
277-
y:patch_delimiters_y[j+1]]
330+
ct_map[x_:x_+local_ct_map.shape[0],
331+
y_:y_+local_ct_map.shape[1]] = local_ct_map
278332

279-
if not len(sdata_patch): break
280-
281-
mins = sdata_patch.coordinates.min(0).astype(int)
282-
maxs = sdata_patch.coordinates.max(0).astype(int)+kernel_bandwidth*3
283-
hists = get_histograms(sdata_patch,mins=mins,maxs=maxs, resolution=output_um_p_px)
284-
hists = np.concatenate([gaussian_filter(h, kernel_bandwidth_px) for h in hists])
333+
pbar.update(1)
285334

286-
print(hists.shape)
335+
ct_map = fill_celltypemaps(
336+
ct_map, min_blob_area=min_blob_area, fill_blobs=fill_blobs)
337+
return PixelMap(ct_map.T, px_p_um=1/output_um_p_px)
287338

288-
norm = np.sum(hists, axis=0)
289339

290-
vf_norm[(x+mins[0])//output_um_p_px:(x+mins[0])//output_um_p_px+norm.shape[0],
291-
(y+mins[1])//output_um_p_px:(y+mins[1])//output_um_p_px+norm.shape[1],]=norm
340+
def localmax_sampling(sdata, n_clusters=10, min_distance=3, bandwidth=4):
341+
n_bins = np.array(sdata.spatial.shape)
292342

293-
mask = norm > threshold_exp
343+
vf = (gaussian_filter(np.histogram2d(
344+
*sdata.coordinates.T, bins=n_bins)[0], 2))
345+
localmaxs = np.where((maximum_filter(vf, min_distance) == vf) & (vf > 0.2))
346+
knn = NearestNeighbors(n_neighbors=150)
347+
knn.fit(sdata.coordinates)
348+
dists, nbrs = knn.kneighbors(np.array(localmaxs).T)
349+
neighbor_types = np.array(sdata.gene_ids)[nbrs]
294350

295-
exps = np.zeros((len(sdata.genes),mask.sum()))
351+
counts = np.zeros((dists.shape[0], len(sdata.genes)))
296352

297-
print(exps.shape,signatures.shape)
298-
exps[sdata.stats.loc[sdata_patch.genes].gene_ids.values,:]=hists[:, mask]
353+
bandwidth = bandwidth
354+
def kernel(x): return np.exp(-x**2/(2*bandwidth**2))
299355

300-
# signatures -= signatures.mean(0)
301-
# signatures /= signatures.std(0)
356+
for i in range(0, dists.shape[1]):
357+
counts[np.arange(dists.shape[0]), neighbor_types[:, i]
358+
] += kernel(dists[:, i])
302359

303-
# exps -= exps.mean(0)
304-
# exps -= exps.std(0)
360+
assert (all(counts.sum(1)) > 0)
305361

306-
local_ct_map = mask.astype(int)-1
307-
# corrs = np.inner(exps.T, signatures)
362+
counts=counts/counts.sum(1)[:,None]
308363

309-
corrs = crosscorr(exps.T,signatures)
310-
print(corrs.min())
311-
corrs_winners = corrs.argmax(1)
312-
corrs_winners[corrs.max(1)<threshold_cor]=background_value
313-
local_ct_map[mask] = corrs_winners
364+
ica = FastICA(n_components=30)
365+
facs = ica.fit_transform(counts)
314366

315-
# idcs=sdata_patch.index
367+
km = KMeans(
368+
n_clusters=n_clusters,
369+
max_iter=100,
370+
n_init=1,).fit(counts[:])
316371

372+
return km.cluster_centers_
317373

318-
x_=int(max(0,(x+mins[0])-kernel_bandwidth_px*3))//output_um_p_px
319-
y_=int(max(0,(y+mins[1])-kernel_bandwidth_px*3))//output_um_p_px
320-
321-
322-
ct_map[x_:x_+local_ct_map.shape[0],
323-
y_:y_+local_ct_map.shape[1]]=local_ct_map
324-
325-
ct_map = fill_celltypemaps(ct_map,min_blob_area=min_blob_area, fill_blobs=fill_blobs)
326-
return PixelMap(ct_map.T,px_p_um=1/output_um_p_px)

0 commit comments

Comments
 (0)