1212import pandas as pd
1313import scanpy as sc
1414from scipy .sparse import csr_matrix
15+ from sklearn .model_selection import train_test_split
1516
1617from dance import logger
1718from dance .data import Data
@@ -52,19 +53,42 @@ class CellTypeAnnotationDataset(BaseDataset):
5253
def __init__(self, full_download=False, train_dataset=None, test_dataset=None, species=None, tissue=None,
             valid_dataset=None, train_dir="train", test_dir="test", valid_dir="valid", map_path="map",
             data_dir="./", train_as_valid=False, val_size=0.2):
    """Store the dataset configuration and prepare per-instance download tables.

    Parameters
    ----------
    full_download
        Forwarded to the base-class constructor together with ``data_dir``.
    train_dataset, test_dataset, valid_dataset
        Dataset ids of the train / test / validation splits.
    species, tissue
        Species and tissue used when building data file names.
    train_dir, test_dir, valid_dir, map_path
        Sub-directory names (relative to ``data_dir``) for each split and for the cell-type mapping files.
    data_dir
        Root directory of the data.
    train_as_valid
        If true and ``valid_dataset`` is ``None``, the training dataset ids are reused as validation ids and
        :meth:`train2valid` registers the matching ``valid`` entries.
    val_size
        Stored as ``self.val_size``; used by ``_load_raw_data`` as the ``test_size`` of ``train_test_split``
        when no validation dataset is configured.

    """
    super().__init__(data_dir, full_download)

    # Where the data lives on disk.
    self.data_dir = data_dir
    self.train_dir = train_dir
    self.test_dir = test_dir
    self.valid_dir = valid_dir
    self.map_path = map_path

    # What data to use.
    self.species = species
    self.tissue = tissue
    self.train_dataset = train_dataset
    self.test_dataset = test_dataset
    self.valid_dataset = valid_dataset

    # How the validation split is obtained.
    self.train_as_valid = train_as_valid
    self.val_size = val_size

    # Work on per-instance copies so that ``train2valid`` never touches the class-level tables.
    self.bench_url_dict = self.BENCH_URL_DICT.copy()
    self.available_data = self.AVAILABLE_DATA.copy()

    if train_as_valid and valid_dataset is None:
        self.valid_dataset = train_dataset
        self.train2valid()
def train2valid(self):
    """Register every ``train`` entry a second time as a ``valid`` entry.

    ``self.available_data`` gains a copy of each record whose ``"split"`` is ``"train"``, with the split changed
    to ``"valid"``. ``self.bench_url_dict`` gains, for each key starting with ``"train"``, a sibling key whose
    first ``"train"`` is replaced by ``"valid"`` and which points at the same URL. Both attributes are rebound to
    new containers; the previous containers are left unmodified.

    """
    logger.info("Copy train_dataset and use it as valid_dataset")

    # Build the extra "valid" records from the current "train" records.
    valid_records = []
    for record in self.available_data:
        if record["split"] != "train":
            continue
        clone = record.copy()
        clone['split'] = 'valid'
        valid_records.append(clone)

    # Mirror every train URL under the corresponding valid key.
    valid_urls = {
        name.replace("train", "valid", 1): url
        for name, url in self.bench_url_dict.items()
        if name.startswith("train")
    }

    extended_data = self.available_data.copy()
    extended_data.extend(valid_records)
    extended_urls = self.bench_url_dict.copy()
    extended_urls.update(valid_urls)

    self.available_data = extended_data
    self.bench_url_dict = extended_urls
6892
6993 def download_all (self ):
7094 if self .is_complete ():
@@ -87,7 +111,8 @@ def download_all(self):
87111
def get_all_filenames(self, filetype: str = "csv", feat_suffix: str = "data", label_suffix: str = "celltype"):
    """Return the feature and label file names of every configured dataset id.

    Dataset ids are taken from ``self.train_dataset``, ``self.test_dataset`` and ``self.valid_dataset`` (in that
    order). Any of the three may be ``None``, in which case it contributes no file names.

    Parameters
    ----------
    filetype
        File extension, without the leading dot.
    feat_suffix
        Suffix of the feature (expression) file names.
    label_suffix
        Suffix of the label (cell type) file names.

    Returns
    -------
    list of str
        For each dataset id, ``{species}_{tissue}{id}_{feat_suffix}.{filetype}`` followed by
        ``{species}_{tissue}{id}_{label_suffix}.{filetype}``.

    """
    # All three id collections default to None in the constructor, so guard each of them instead of only the
    # validation one (concatenating None with a list raises TypeError).
    dataset_ids = []
    for ids in (self.train_dataset, self.test_dataset, self.valid_dataset):
        if ids is not None:
            dataset_ids.extend(ids)

    prefix = f"{self.species}_{self.tissue}"
    filenames = []
    for dataset_id in dataset_ids:
        filenames.append(f"{prefix}{dataset_id}_{feat_suffix}.{filetype}")
        filenames.append(f"{prefix}{dataset_id}_{label_suffix}.{filetype}")
    return filenames
@@ -98,7 +123,7 @@ def download(self, download_map=True):
98123
99124 filenames = self .get_all_filenames ()
100125 # Download training and testing data
101- for name , url in self .BENCH_URL_DICT .items ():
126+ for name , url in self .bench_url_dict .items ():
102127 parts = name .split ("_" ) # [train|test]_{species}_{tissue}{id}_[celltype|data].csv
103128 filename = "_" .join (parts [1 :])
104129 if filename in filenames :
@@ -115,7 +140,6 @@ def is_complete_all(self):
115140 check = [
116141 osp .join (self .data_dir , "train" ),
117142 osp .join (self .data_dir , "test" ),
118- osp .join (self .data_dir , "valid" ),
119143 osp .join (self .data_dir , "pretrained" )
120144 ]
121145 for i in check :
@@ -126,7 +150,7 @@ def is_complete_all(self):
126150
127151 def is_complete (self ):
128152 """Check if benchmarking data is complete."""
129- for name in self .BENCH_URL_DICT :
153+ for name in self .bench_url_dict :
130154 if any (i not in name for i in (self .species , self .tissue )):
131155 continue
132156 filename = name [name .find (self .species ):]
@@ -150,58 +174,101 @@ def is_complete(self):
def _load_raw_data(self, ct_col: str = "Cell_type") -> Tuple[ad.AnnData, List[Set[str]], List[str], int, int]:
    """Load the train / valid / test tables and stack them into a single AnnData.

    The validation split comes from one of three places:

    * ``self.valid_dataset`` is set: it is loaded from ``self.valid_dir``;
    * otherwise, if ``self.val_size > 0``: it is carved out of the training data with ``train_test_split``;
    * otherwise there is no validation split.

    Parameters
    ----------
    ct_col
        Name of the label-table column holding the cell type.

    Returns
    -------
    adata
        Features stacked as train, (valid,) test, restricted to the training columns and with missing values
        replaced by 0.
    labels
        Labels in the same row order as ``adata``. Training labels are kept as is; a valid / test label that
        does not occur in the training labels is replaced by ``cell_type_mappings.get(label)``.
    idx_to_label
        Sorted unique training cell types.
    train_size
        Number of training rows.
    valid_size
        Number of validation rows, ``0`` when there is no validation split.

    """
    species = self.species
    tissue = self.tissue
    data_dir = self.data_dir
    map_path = osp.join(data_dir, self.map_path, self.species)

    def load_split(split_dir, dataset_ids):
        # Resolve the file paths of one split and load its feature and label tables.
        feat_paths, label_paths = self._get_data_paths(osp.join(data_dir, split_dir), species, tissue, dataset_ids)
        return self._load_dfs(feat_paths, transpose=True), self._load_dfs(label_paths)

    # Load raw data
    train_feat, train_label = load_split(self.train_dir, self.train_dataset)
    test_feat, test_label = load_split(self.test_dir, self.test_dataset)

    valid_feat = valid_label = None
    if self.valid_dataset is not None:
        valid_feat, valid_label = load_split(self.valid_dir, self.valid_dataset)
    elif self.val_size > 0:
        # No random_state is passed, so the split follows NumPy's global RNG state.
        # NOTE(review): the split is not stratified; a cell type that ends up only in the validation part is
        # no longer a "training" cell type below and goes through the mapping lookup -- confirm this is intended.
        train_feat, valid_feat, train_label, valid_label = train_test_split(train_feat, train_label,
                                                                            test_size=self.val_size)
    has_valid = valid_feat is not None

    def align_to_train(other_feat):
        # Keep only the training columns; columns missing from ``other_feat`` are filled with 0.
        return train_feat.align(other_feat, axis=1, join="left", fill_value=0)

    # Combine features (only use features that are present in the training data)
    train_size = train_feat.shape[0]
    valid_size = valid_feat.shape[0] if has_valid else 0
    if has_valid:
        frames = align_to_train(valid_feat) + align_to_train(test_feat)[1:]
    else:
        frames = align_to_train(test_feat)
    feat_df = pd.concat(frames).fillna(0)
    adata = ad.AnnData(feat_df, dtype=np.float32)

    # Convert cell type labels and map valid / test cell type names to train
    cell_types = set(train_label[ct_col].unique())
    idx_to_label = sorted(cell_types)
    cell_type_mappings: Dict[str, Set[str]] = self.get_map_dict(map_path, tissue)

    def map_to_train(label_df):
        # Known training cell types pass through; anything else is looked up in the mapping (None if absent).
        return [ct if ct in cell_types else cell_type_mappings.get(ct) for ct in label_df[ct_col]]

    def log_mapping(split_name, label_df, mapped):
        logger.debug(f"Mapped {split_name} cell-types:")
        for idx, raw, new in zip(label_df.index, label_df[ct_col], mapped):
            logger.debug(f"{idx}:{raw}\t-> {new}")

    train_labels = train_label[ct_col].tolist()
    valid_labels = map_to_train(valid_label) if has_valid else []
    test_labels = map_to_train(test_label)
    labels = train_labels + valid_labels + test_labels

    if has_valid:
        log_mapping("valid", valid_label, valid_labels)
    log_mapping("test", test_label, test_labels)

    logger.info(f"Loaded expression data: {adata}")
    logger.info(f"Number of training samples: {train_feat.shape[0]:,}")
    if has_valid:
        logger.info(f"Number of valid samples: {valid_feat.shape[0]:,}")
    logger.info(f"Number of testing samples: {test_feat.shape[0]:,}")
    logger.info(f"Cell-types (n={len(idx_to_label)}):\n{pprint.pformat(idx_to_label)}")

    return adata, labels, idx_to_label, train_size, valid_size
205272
206273 def _raw_to_dance (self , raw_data ):
207274 adata , cell_labels , idx_to_label , train_size , valid_size = raw_data
@@ -290,9 +357,10 @@ def is_complete(self):
290357 return osp .exists (self .data_path )
291358
def _load_raw_data(self) -> Tuple[ad.AnnData, np.ndarray]:
    """Read the ``X`` and ``Y`` datasets from the HDF5 file at ``self.data_path``.

    Returns
    -------
    adata
        AnnData built from ``X`` with dtype ``np.float32``.
    y
        The ``Y`` dataset as a NumPy array.

    """
    # The file is opened by Python first and the binary handle is handed to h5py as a file-like object.
    with open(self.data_path, "rb") as raw_file, h5py.File(raw_file, "r") as h5_file:
        features = np.array(h5_file["X"])
        targets = np.array(h5_file["Y"])
    return ad.AnnData(features, dtype=np.float32), targets
298366
0 commit comments