Skip to content

Commit 96b5b1a

Browse files
committed
add StarDist+RCTD and 6 benchmarking datasets
1 parent 9c0c8a9 commit 96b5b1a

File tree

9 files changed

+17108
-12
lines changed

9 files changed

+17108
-12
lines changed

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,12 @@ We provide source codes for reproducing the SpatialScope analysis in the main te
2929

3030
All relevent materials involved in the reproducing codes are availabel from [here](https://drive.google.com/drive/folders/1PXv_brtr-tXshBVEd_HSPIagjX9oF7Kg?usp=sharing)
3131

32-
+ [Benchmarking](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-MERFISH.ipynb)
32+
+ [Benchmarking Dataset 1](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_1.ipynb)
33+
+ [Benchmarking Dataset 2](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_2.ipynb)
34+
+ [Benchmarking Dataset 3](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_3.ipynb)
35+
+ [Benchmarking Dataset 4](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_4.ipynb)
36+
+ [Benchmarking Dataset 5](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_5.ipynb)
37+
+ [Benchmarking Dataset 6](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Benchmarking-Dataset_6.ipynb)
3338
+ [Human Heart (Visium, a single slice)](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Human-Heart.ipynb)
3439
+ [Mouse Brain (Visium, 3D alignment of multiple slices)](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Mouse-Brain.ipynb)
3540
+ [Mouse Cerebellum (Slideseq-V2)](https://github.com/YangLabHKUST/SpatialScope/blob/master/demos/Mouse-Cerebellum-Slideseq.ipynb)

compared_methods/SDRCTD.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
import anndata
2+
import numpy as np
3+
import pandas as pd
4+
import sys
5+
import pickle
6+
import os
7+
import copy
8+
import argparse
9+
from sklearn.model_selection import KFold
10+
from sklearn.metrics import mean_squared_error
11+
import pandas as pd
12+
import matplotlib.pyplot as plt
13+
import scanpy as sc
14+
import warnings
15+
warnings.filterwarnings('ignore')
16+
import seaborn as sns
17+
18+
from SDRCTD_utils import *
19+
20+
21+
22+
23+
class SDRCTD:
24+
def __init__(self,tissue,out_dir,RCTD_results_dir,RCTD_results_name, ST_Data, SC_Data, cell_class_column = 'cell_type', hs_ST = True):
25+
self.tissue = tissue
26+
self.out_dir = out_dir
27+
self.RCTD_results_dir = RCTD_results_dir
28+
self.RCTD_results_name = RCTD_results_name
29+
self.ST_Data = ST_Data
30+
self.SC_Data = SC_Data
31+
self.cell_class_column = cell_class_column
32+
self.hs_ST = hs_ST
33+
34+
if not os.path.exists(out_dir):
35+
os.mkdir(out_dir)
36+
if not os.path.exists(os.path.join(out_dir,tissue)):
37+
os.mkdir(os.path.join(out_dir,tissue))
38+
39+
self.out_dir = os.path.join(out_dir,tissue)
40+
loggings = configure_logging(os.path.join(self.out_dir,'logs'))
41+
self.loggings = loggings
42+
43+
self.LoadRCTDresults()
44+
if SC_Data is not None:
45+
self.LoadSCData()
46+
47+
48+
def LoadRCTDresults(self):
49+
with open(os.path.join(self.RCTD_results_dir, self.RCTD_results_name + '.pickle'), 'rb') as handle:
50+
RCTD_results = pickle.load(handle)
51+
52+
if self.hs_ST:
53+
try:
54+
weights = RCTD_results['results']['weights']
55+
except:
56+
weights = RCTD_results['results']
57+
else:
58+
weights = RCTD_results['results']
59+
60+
61+
self.weights = (weights / np.array(weights.sum(1))[:, None])
62+
63+
def single_cell_type_assignment(self, cell_num_column = 'cell_count', VisiumCellsPlot = True):
64+
seged_sp_adata = sc.read(self.ST_Data) #ST_Data already complete nuclei segmentation with StarDist. 'cell_locations' already in uns
65+
66+
mat = self.weights.values
67+
cell_nums = np.array(seged_sp_adata.obs[cell_num_column])
68+
69+
cell_counts = distribute_cells(mat, cell_nums)
70+
cell_types = self.weights.columns
71+
cell_type_list = assign_cell_type(cell_counts, cell_types)
72+
seged_sp_adata.uns['cell_locations']['SDRCTD_cell_type'] = cell_type_list
73+
74+
self.cell_type_list = cell_type_list
75+
self.seged_sp_adata = seged_sp_adata
76+
77+
seged_sp_adata.uns['RCTD_weights'] = self.weights
78+
seged_sp_adata.write(os.path.join(self.out_dir, 'single_cell_type_label_bySDRCTD.h5ad'))
79+
80+
# plot results
81+
if self.hs_ST or not VisiumCellsPlot:
82+
fig, ax = plt.subplots(figsize=(10,8.5),dpi=100)
83+
sns.scatterplot(data=seged_sp_adata.uns['cell_locations'], x="x",y="y",s=10,hue='SDRCTD_cell_type',palette='tab20',legend=True)
84+
plt.axis('off')
85+
plt.legend(bbox_to_anchor=(0.97, .98),framealpha=0)
86+
plt.savefig(os.path.join(self.out_dir, 'SDRCTD_estemated_ct_label.png'))
87+
plt.close()
88+
89+
elif VisiumCellsPlot:
90+
if seged_sp_adata.obsm['spatial'].shape[1] == 2:
91+
fig, ax = plt.subplots(1,1,figsize=(14, 8),dpi=200)
92+
PlotVisiumCells(seged_sp_adata,"SDRCTD_cell_type",size=0.4,alpha_img=0.4,lw=0.4,palette='tab20',ax=ax)
93+
plt.savefig(os.path.join(self.out_dir, 'SDRCTD_estemated_ct_label.png'))
94+
plt.close()
95+
96+
def cell_type_mean_assignment(self):
97+
# cell type mean as decomposed cell gene expression
98+
ref_df = pd.DataFrame([[ct, i]for i, ct in enumerate(self.sc_data_process_marker.obs[self.cell_class_column].astype('category').cat.categories)], columns = ['cell_type', 'cell_type_code'])
99+
ref_df.index = ref_df.cell_type
100+
ref_df = ref_df.iloc[:,1:]
101+
102+
x_decom = self.mu[ref_df.loc[np.array(self.cell_type_list)].cell_type_code.tolist()]
103+
x_decom_adata = anndata.AnnData(X = x_decom.copy(), obs = self.seged_sp_adata.uns['cell_locations'].copy(), var = self.sc_data_process_marker.var)
104+
x_decom_adata.write(os.path.join(self.out_dir, 'cell_type_mean_bySDRCTD.h5ad'))
105+
106+
107+
def LoadSCData(self):
108+
# load sc data
109+
sc_data_process = anndata.read_h5ad(self.SC_Data)
110+
if 'Marker' in sc_data_process.var.columns:
111+
sc_data_process_marker = sc_data_process[:,sc_data_process.var['Marker']]
112+
else:
113+
sc_data_process_marker = sc_data_process
114+
115+
if sc_data_process_marker.X.max() <= 30:
116+
self.loggings.info(f'Maximum value: {sc_data_process_marker.X.max()}, need to run exp')
117+
try:
118+
sc_data_process_marker.X = np.exp(sc_data_process_marker.X) - 1
119+
except:
120+
sc_data_process_marker.X = np.exp(sc_data_process_marker.X.toarray()) - 1
121+
122+
123+
cell_type_array = np.array(sc_data_process_marker.obs[self.cell_class_column])
124+
cell_type_class = np.unique(cell_type_array)
125+
df_category = sc_data_process_marker.obs[[self.cell_class_column]].astype('category').apply(lambda x: x.cat.codes)
126+
127+
# parameters: mean and cell type index
128+
cell_type_array_code = np.array(df_category[self.cell_class_column])
129+
try:
130+
data = sc_data_process_marker.X.toarray()
131+
except:
132+
data = sc_data_process_marker.X
133+
134+
n, d = data.shape
135+
q = cell_type_class.shape[0]
136+
self.loggings.info(f'scRNA-seq data shape: {data.shape}')
137+
self.loggings.info(f'scRNA-seq cell class number: {q}')
138+
139+
mu = np.zeros((q, d))
140+
for k in range(q):
141+
mu[k] = data[cell_type_array_code == k].mean(0).squeeze()
142+
self.mu = mu
143+
self.sc_data_process_marker = sc_data_process_marker
144+
145+
146+
147+
148+
149+
150+
151+
if __name__ == "__main__":
152+
HEADER = """
153+
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
154+
<>
155+
<> StarDist + RCTD
156+
<> Version: %s
157+
<> MIT License
158+
<>
159+
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
160+
<> Software-related correspondence: %s or %s
161+
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
162+
<> Visium data example:
163+
python <install path>/src/Cell_Type_Identification.py \\
164+
--cell_class_column cell_type \\
165+
--tissue heart \\
166+
--out_dir ./output \\
167+
--ST_Data ./output/heart/sp_adata_ns.h5ad \\
168+
--SC_Data ./Ckpts_scRefs/Heart_D2/Ref_Heart_sanger_D2.h5ad
169+
<><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
170+
"""
171+
def str2bool(v):
172+
if isinstance(v, bool):
173+
return v
174+
if v.lower() in ("yes", "true", "t", "y", "1"):
175+
return True
176+
elif v.lower() in ("no", "false", "f", "n", "0"):
177+
return False
178+
else:
179+
raise argparse.ArgumentTypeError("Boolean value expected.")
180+
181+
parser = argparse.ArgumentParser(description='simulation sour_sep')
182+
parser.add_argument('--out_dir', type=str, help='output path', default=None)
183+
parser.add_argument('--RCTD_results_dir', type=str, help='RCTD results path', default=None)
184+
parser.add_argument('--RCTD_results_name', type=str, help='RCTD results file\'s name', default='InitProp')
185+
parser.add_argument('--ST_Data', type=str, help='ST data path', default=None)
186+
parser.add_argument('--SC_Data', type=str, help='single cell reference data path', default=None)
187+
parser.add_argument('--cell_class_column', type=str, help='input cell class label column in scRef file', default = 'cell_type')
188+
parser.add_argument('--cell_num_column', type=str, help='cell number column in spatial file', default = 'cell_count')
189+
parser.add_argument('--hs_ST', action="store_true", help='high resolution ST data such as Slideseq, DBiT-seq, and HDST, MERFISH etc.')
190+
parser.add_argument("--VisiumCellsPlot", type=str2bool, const=True, default=True, nargs="?", help="whether to plot in VisiumCells mode or just scatter plot")
191+
args = parser.parse_args()
192+
193+
args.tissue = 'SDRCTD_results'
194+
if not os.path.exists(args.out_dir):
195+
os.mkdir(args.out_dir)
196+
if not os.path.exists(os.path.join(args.out_dir,args.tissue)):
197+
os.mkdir(os.path.join(args.out_dir,args.tissue))
198+
if args.RCTD_results_dir is None:
199+
args.RCTD_results_dir = args.out_dir
200+
201+
202+
sdr = SDRCTD(args.tissue,args.out_dir, args.RCTD_results_dir, args.RCTD_results_name, args.ST_Data, args.SC_Data, cell_class_column = args.cell_class_column, hs_ST = args.hs_ST)
203+
sdr.single_cell_type_assignment(cell_num_column = args.cell_num_column, VisiumCellsPlot = args.VisiumCellsPlot)
204+
if args.SC_Data is not None:
205+
sdr.cell_type_mean_assignment()

compared_methods/SDRCTD_utils.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import os
2+
import scanpy as sc
3+
# import squidpy as sq
4+
import numpy as np
5+
import pandas as pd
6+
import pathlib
7+
import matplotlib.pyplot as plt
8+
import matplotlib as mpl
9+
# import skimage
10+
import seaborn as sns
11+
from itertools import chain
12+
# from stardist.models import StarDist2D
13+
from csbdeep.utils import normalize
14+
from anndata import AnnData
15+
from scipy.spatial.distance import pdist
16+
import logging
17+
import sys
18+
from sklearn.metrics.pairwise import cosine_similarity
19+
20+
def PlotVisiumCells(adata,annotation_list,size=0.8,alpha_img=0.3,lw=1,subset=None,palette='tab20',show_circle = True, legend = True, ax=None,**kwargs):
21+
merged_df = adata.uns['cell_locations'].copy()
22+
test = sc.AnnData(np.zeros(merged_df.shape), obs=merged_df)
23+
test.obsm['spatial'] = merged_df[["x", "y"]].to_numpy().copy()
24+
test.uns = adata.uns
25+
26+
if subset is not None:
27+
#test = test[test.obs[annotation_list].isin(subset)]
28+
test.obs.loc[~test.obs[annotation_list].isin(subset),annotation_list] = None
29+
30+
sc.pl.spatial(
31+
test,
32+
color=annotation_list,
33+
size=size,
34+
frameon=False,
35+
alpha_img=alpha_img,
36+
show=False,
37+
palette=palette,
38+
na_in_legend=False,
39+
ax=ax,title='',sort_order=True,**kwargs
40+
)
41+
if show_circle:
42+
sf = adata.uns['spatial'][list(adata.uns['spatial'].keys())[0]]['scalefactors']['tissue_hires_scalef']
43+
spot_radius = adata.uns['spatial'][list(adata.uns['spatial'].keys())[0]]['scalefactors']['spot_diameter_fullres']/2
44+
for sloc in adata.obsm['spatial']:
45+
rect = mpl.patches.Circle(
46+
(sloc[0] * sf, sloc[1] * sf),
47+
spot_radius * sf,
48+
ec="grey",
49+
lw=lw,
50+
fill=False
51+
)
52+
ax.add_patch(rect)
53+
ax.axes.xaxis.label.set_visible(False)
54+
ax.axes.yaxis.label.set_visible(False)
55+
56+
if not legend:
57+
ax.get_legend().remove()
58+
59+
# make frame visible
60+
for _, spine in ax.spines.items():
61+
spine.set_visible(True)
62+
63+
64+
65+
def assign_cell_type(cell_counts, cell_types):
66+
cell_type_list = []
67+
for i in range(cell_counts.shape[0]):
68+
cell_count = cell_counts[i]
69+
idx = np.where(cell_count > 0)[0]
70+
cell_type_list_row = [[cell_types[idx][_]] * cell_count[idx[_]] for _ in range(idx.shape[0])]
71+
cell_type_list_row = np.array([item for sublist in cell_type_list_row for item in sublist])
72+
np.random.shuffle(cell_type_list_row)
73+
cell_type_list = cell_type_list + list(cell_type_list_row)
74+
75+
return cell_type_list
76+
77+
def distribute_cells(mat, cell_nums): # mat: spots * cell_type; cell_nums: spots * 1
78+
cell_nums = cell_nums.astype(int)
79+
cell_nums_original = cell_nums.copy()
80+
81+
mat[np.absolute(mat) < 1e-3] = 0
82+
cell_counts = np.zeros(mat.shape).astype(int)
83+
assert not np.any(cell_nums < 0)
84+
assert not np.any(mat < 0)
85+
86+
mat = mat * cell_nums[:, None]
87+
cell_num_dist = np.floor(mat).astype(int)
88+
cell_counts = cell_counts + cell_num_dist
89+
cell_nums_remain = cell_nums - cell_num_dist.sum(1)
90+
mat_remain = mat - cell_num_dist
91+
92+
assert not np.any(cell_nums_remain < 0)
93+
assert not np.any(mat_remain < 0)
94+
95+
mat = mat_remain
96+
cell_nums = cell_nums_remain
97+
98+
while(np.any(cell_nums_remain > 0)):
99+
mat[mat.argsort()[:, ::-1].argsort() >= cell_nums[:, None]] = 0
100+
mat = np.divide(mat, mat.sum(1)[:,None], out=np.zeros_like(mat), where=mat.sum(1)[:,None]!=0)
101+
mat = mat * cell_nums[:, None]
102+
cell_num_dist = np.floor(mat).astype(int)
103+
cell_counts = cell_counts + cell_num_dist
104+
cell_nums_remain = cell_nums - cell_num_dist.sum(1)
105+
mat_remain = mat - cell_num_dist
106+
107+
assert not np.any(cell_nums_remain < 0)
108+
assert not np.any(mat_remain < 0)
109+
110+
mat = mat_remain
111+
cell_nums = cell_nums_remain
112+
113+
assert np.array_equal(cell_counts.sum(1), cell_nums_original)
114+
115+
return cell_counts
116+
117+
118+
119+
120+
def configure_logging(logger_name):
121+
LOG_LEVEL = logging.DEBUG
122+
log_filename = logger_name+'.log'
123+
importer_logger = logging.getLogger('importer_logger')
124+
importer_logger.setLevel(LOG_LEVEL)
125+
formatter = logging.Formatter('%(asctime)s : %(levelname)s : %(message)s')
126+
127+
fh = logging.FileHandler(filename=log_filename)
128+
fh.setLevel(LOG_LEVEL)
129+
fh.setFormatter(formatter)
130+
importer_logger.addHandler(fh)
131+
132+
sh = logging.StreamHandler(sys.stdout)
133+
sh.setLevel(LOG_LEVEL)
134+
sh.setFormatter(formatter)
135+
importer_logger.addHandler(sh)
136+
return importer_logger
137+
138+
139+

0 commit comments

Comments
 (0)