1+ import anndata
2+ import numpy as np
3+ import pandas as pd
4+ import sys
5+ import pickle
6+ import os
7+ import copy
8+ import argparse
9+ from sklearn .model_selection import KFold
10+ from sklearn .metrics import mean_squared_error
11+ import pandas as pd
12+ import matplotlib .pyplot as plt
13+ import scanpy as sc
14+ import warnings
15+ warnings .filterwarnings ('ignore' )
16+ import seaborn as sns
17+
18+ from SDRCTD_utils import *
19+
20+
21+
22+
23+ class SDRCTD :
24+ def __init__ (self ,tissue ,out_dir ,RCTD_results_dir ,RCTD_results_name , ST_Data , SC_Data , cell_class_column = 'cell_type' , hs_ST = True ):
25+ self .tissue = tissue
26+ self .out_dir = out_dir
27+ self .RCTD_results_dir = RCTD_results_dir
28+ self .RCTD_results_name = RCTD_results_name
29+ self .ST_Data = ST_Data
30+ self .SC_Data = SC_Data
31+ self .cell_class_column = cell_class_column
32+ self .hs_ST = hs_ST
33+
34+ if not os .path .exists (out_dir ):
35+ os .mkdir (out_dir )
36+ if not os .path .exists (os .path .join (out_dir ,tissue )):
37+ os .mkdir (os .path .join (out_dir ,tissue ))
38+
39+ self .out_dir = os .path .join (out_dir ,tissue )
40+ loggings = configure_logging (os .path .join (self .out_dir ,'logs' ))
41+ self .loggings = loggings
42+
43+ self .LoadRCTDresults ()
44+ if SC_Data is not None :
45+ self .LoadSCData ()
46+
47+
48+ def LoadRCTDresults (self ):
49+ with open (os .path .join (self .RCTD_results_dir , self .RCTD_results_name + '.pickle' ), 'rb' ) as handle :
50+ RCTD_results = pickle .load (handle )
51+
52+ if self .hs_ST :
53+ try :
54+ weights = RCTD_results ['results' ]['weights' ]
55+ except :
56+ weights = RCTD_results ['results' ]
57+ else :
58+ weights = RCTD_results ['results' ]
59+
60+
61+ self .weights = (weights / np .array (weights .sum (1 ))[:, None ])
62+
63+ def single_cell_type_assignment (self , cell_num_column = 'cell_count' , VisiumCellsPlot = True ):
64+ seged_sp_adata = sc .read (self .ST_Data ) #ST_Data already complete nuclei segmentation with StarDist. 'cell_locations' already in uns
65+
66+ mat = self .weights .values
67+ cell_nums = np .array (seged_sp_adata .obs [cell_num_column ])
68+
69+ cell_counts = distribute_cells (mat , cell_nums )
70+ cell_types = self .weights .columns
71+ cell_type_list = assign_cell_type (cell_counts , cell_types )
72+ seged_sp_adata .uns ['cell_locations' ]['SDRCTD_cell_type' ] = cell_type_list
73+
74+ self .cell_type_list = cell_type_list
75+ self .seged_sp_adata = seged_sp_adata
76+
77+ seged_sp_adata .uns ['RCTD_weights' ] = self .weights
78+ seged_sp_adata .write (os .path .join (self .out_dir , 'single_cell_type_label_bySDRCTD.h5ad' ))
79+
80+ # plot results
81+ if self .hs_ST or not VisiumCellsPlot :
82+ fig , ax = plt .subplots (figsize = (10 ,8.5 ),dpi = 100 )
83+ sns .scatterplot (data = seged_sp_adata .uns ['cell_locations' ], x = "x" ,y = "y" ,s = 10 ,hue = 'SDRCTD_cell_type' ,palette = 'tab20' ,legend = True )
84+ plt .axis ('off' )
85+ plt .legend (bbox_to_anchor = (0.97 , .98 ),framealpha = 0 )
86+ plt .savefig (os .path .join (self .out_dir , 'SDRCTD_estemated_ct_label.png' ))
87+ plt .close ()
88+
89+ elif VisiumCellsPlot :
90+ if seged_sp_adata .obsm ['spatial' ].shape [1 ] == 2 :
91+ fig , ax = plt .subplots (1 ,1 ,figsize = (14 , 8 ),dpi = 200 )
92+ PlotVisiumCells (seged_sp_adata ,"SDRCTD_cell_type" ,size = 0.4 ,alpha_img = 0.4 ,lw = 0.4 ,palette = 'tab20' ,ax = ax )
93+ plt .savefig (os .path .join (self .out_dir , 'SDRCTD_estemated_ct_label.png' ))
94+ plt .close ()
95+
96+ def cell_type_mean_assignment (self ):
97+ # cell type mean as decomposed cell gene expression
98+ ref_df = pd .DataFrame ([[ct , i ]for i , ct in enumerate (self .sc_data_process_marker .obs [self .cell_class_column ].astype ('category' ).cat .categories )], columns = ['cell_type' , 'cell_type_code' ])
99+ ref_df .index = ref_df .cell_type
100+ ref_df = ref_df .iloc [:,1 :]
101+
102+ x_decom = self .mu [ref_df .loc [np .array (self .cell_type_list )].cell_type_code .tolist ()]
103+ x_decom_adata = anndata .AnnData (X = x_decom .copy (), obs = self .seged_sp_adata .uns ['cell_locations' ].copy (), var = self .sc_data_process_marker .var )
104+ x_decom_adata .write (os .path .join (self .out_dir , 'cell_type_mean_bySDRCTD.h5ad' ))
105+
106+
107+ def LoadSCData (self ):
108+ # load sc data
109+ sc_data_process = anndata .read_h5ad (self .SC_Data )
110+ if 'Marker' in sc_data_process .var .columns :
111+ sc_data_process_marker = sc_data_process [:,sc_data_process .var ['Marker' ]]
112+ else :
113+ sc_data_process_marker = sc_data_process
114+
115+ if sc_data_process_marker .X .max () <= 30 :
116+ self .loggings .info (f'Maximum value: { sc_data_process_marker .X .max ()} , need to run exp' )
117+ try :
118+ sc_data_process_marker .X = np .exp (sc_data_process_marker .X ) - 1
119+ except :
120+ sc_data_process_marker .X = np .exp (sc_data_process_marker .X .toarray ()) - 1
121+
122+
123+ cell_type_array = np .array (sc_data_process_marker .obs [self .cell_class_column ])
124+ cell_type_class = np .unique (cell_type_array )
125+ df_category = sc_data_process_marker .obs [[self .cell_class_column ]].astype ('category' ).apply (lambda x : x .cat .codes )
126+
127+ # parameters: mean and cell type index
128+ cell_type_array_code = np .array (df_category [self .cell_class_column ])
129+ try :
130+ data = sc_data_process_marker .X .toarray ()
131+ except :
132+ data = sc_data_process_marker .X
133+
134+ n , d = data .shape
135+ q = cell_type_class .shape [0 ]
136+ self .loggings .info (f'scRNA-seq data shape: { data .shape } ' )
137+ self .loggings .info (f'scRNA-seq cell class number: { q } ' )
138+
139+ mu = np .zeros ((q , d ))
140+ for k in range (q ):
141+ mu [k ] = data [cell_type_array_code == k ].mean (0 ).squeeze ()
142+ self .mu = mu
143+ self .sc_data_process_marker = sc_data_process_marker
144+
145+
146+
147+
148+
149+
150+
151+ if __name__ == "__main__" :
152+ HEADER = """
153+ <><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
154+ <>
155+ <> StarDist + RCTD
156+ <> Version: %s
157+ <> MIT License
158+ <>
159+ <><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
160+ <> Software-related correspondence: %s or %s
161+ <><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
162+ <> Visium data example:
163+ python <install path>/src/Cell_Type_Identification.py \\
164+ --cell_class_column cell_type \\
165+ --tissue heart \\
166+ --out_dir ./output \\
167+ --ST_Data ./output/heart/sp_adata_ns.h5ad \\
168+ --SC_Data ./Ckpts_scRefs/Heart_D2/Ref_Heart_sanger_D2.h5ad
169+ <><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><><>
170+ """
171+ def str2bool (v ):
172+ if isinstance (v , bool ):
173+ return v
174+ if v .lower () in ("yes" , "true" , "t" , "y" , "1" ):
175+ return True
176+ elif v .lower () in ("no" , "false" , "f" , "n" , "0" ):
177+ return False
178+ else :
179+ raise argparse .ArgumentTypeError ("Boolean value expected." )
180+
181+ parser = argparse .ArgumentParser (description = 'simulation sour_sep' )
182+ parser .add_argument ('--out_dir' , type = str , help = 'output path' , default = None )
183+ parser .add_argument ('--RCTD_results_dir' , type = str , help = 'RCTD results path' , default = None )
184+ parser .add_argument ('--RCTD_results_name' , type = str , help = 'RCTD results file\' s name' , default = 'InitProp' )
185+ parser .add_argument ('--ST_Data' , type = str , help = 'ST data path' , default = None )
186+ parser .add_argument ('--SC_Data' , type = str , help = 'single cell reference data path' , default = None )
187+ parser .add_argument ('--cell_class_column' , type = str , help = 'input cell class label column in scRef file' , default = 'cell_type' )
188+ parser .add_argument ('--cell_num_column' , type = str , help = 'cell number column in spatial file' , default = 'cell_count' )
189+ parser .add_argument ('--hs_ST' , action = "store_true" , help = 'high resolution ST data such as Slideseq, DBiT-seq, and HDST, MERFISH etc.' )
190+ parser .add_argument ("--VisiumCellsPlot" , type = str2bool , const = True , default = True , nargs = "?" , help = "whether to plot in VisiumCells mode or just scatter plot" )
191+ args = parser .parse_args ()
192+
193+ args .tissue = 'SDRCTD_results'
194+ if not os .path .exists (args .out_dir ):
195+ os .mkdir (args .out_dir )
196+ if not os .path .exists (os .path .join (args .out_dir ,args .tissue )):
197+ os .mkdir (os .path .join (args .out_dir ,args .tissue ))
198+ if args .RCTD_results_dir is None :
199+ args .RCTD_results_dir = args .out_dir
200+
201+
202+ sdr = SDRCTD (args .tissue ,args .out_dir , args .RCTD_results_dir , args .RCTD_results_name , args .ST_Data , args .SC_Data , cell_class_column = args .cell_class_column , hs_ST = args .hs_ST )
203+ sdr .single_cell_type_assignment (cell_num_column = args .cell_num_column , VisiumCellsPlot = args .VisiumCellsPlot )
204+ if args .SC_Data is not None :
205+ sdr .cell_type_mean_assignment ()
0 commit comments