Skip to content

Commit 3620ce1

Browse files
Merge pull request #423 from OmicsML/celltype_annotation_automl
Celltype annotation automl
2 parents c69d017 + d4f410f commit 3620ce1

File tree

106 files changed

+5815
-805
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

106 files changed

+5815
-805
lines changed

dance/data/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
import anndata
1010
import mudata
1111
import numpy as np
12+
import omegaconf
1213
import pandas as pd
1314
import scipy.sparse as sp
1415
import torch
1516

1617
from dance import logger
17-
from dance.typing import Any, Dict, FeatType, Iterator, List, Literal, Optional, Sequence, Tuple, Union
18+
from dance.typing import Any, Dict, FeatType, Iterator, List, ListConfig, Literal, Optional, Sequence, Tuple, Union
1819

1920

2021
def _ensure_iter(val: Optional[Union[List[str], str]]) -> Iterator[Optional[str]]:
@@ -34,7 +35,7 @@ def _check_types_and_sizes(types, sizes):
3435
raise TypeError(f"Found mixed types: {types}. Input configs must be either all str or all lists.")
3536
elif ((type_ := types.pop()) == list) and (len(sizes) > 1):
3637
raise ValueError(f"Found mixed sizes lists: {sizes}. Input configs must be of same length.")
37-
elif type_ not in (list, str):
38+
elif type_ not in (list, str, ListConfig):
3839
raise TypeError(f"Unknownn type {type_} found in config.")
3940

4041

@@ -240,7 +241,7 @@ def set_config_from_dict(self, config_dict: Dict[str, Any], *, overwrite: bool =
240241
label_configs = [j for i, j in config_dict.items() if i in self._LABEL_CONFIGS and j is not None]
241242

242243
# Check type and length consistencies for feature and label configs
243-
for i in (feature_configs, label_configs):
244+
for i in [feature_configs, label_configs]:
244245
types = set(map(type, i))
245246
sizes = set(map(len, i))
246247
_check_types_and_sizes(types, sizes)
@@ -249,6 +250,9 @@ def set_config_from_dict(self, config_dict: Dict[str, Any], *, overwrite: bool =
249250
for config_key, config_val in config_dict.items():
250251
# New config
251252
if config_key not in self.config:
253+
if isinstance(config_val, ListConfig):
254+
config_val = omegaconf.OmegaConf.to_object(config_val)
255+
logger.warning(f"transform ListConfig {config_val} to List")
252256
self.config[config_key] = config_val
253257
logger.info(f"Setting config {config_key!r} to {config_val!r}")
254258
continue

dance/datasets/singlemodality.py

Lines changed: 129 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pandas as pd
1313
import scanpy as sc
1414
from scipy.sparse import csr_matrix
15+
from sklearn.model_selection import train_test_split
1516

1617
from dance import logger
1718
from dance.data import Data
@@ -52,19 +53,42 @@ class CellTypeAnnotationDataset(BaseDataset):
5253

5354
def __init__(self, full_download=False, train_dataset=None, test_dataset=None, species=None, tissue=None,
             valid_dataset=None, train_dir="train", test_dir="test", valid_dir="valid", map_path="map",
             data_dir="./", train_as_valid=False, val_size=0.2):
    """Store the dataset selection and directory layout on the instance.

    Parameters
    ----------
    full_download
        Forwarded to the parent constructor together with ``data_dir``.
    train_dataset, test_dataset, valid_dataset
        Dataset ids of the train / test / valid splits.
    species, tissue
        Stored as-is; other methods use them to build file names.
    train_dir, test_dir, valid_dir, map_path
        Sub-directory names, joined onto ``data_dir`` by the loading code.
    data_dir
        Root data directory.
    train_as_valid
        When true and no ``valid_dataset`` is given, the train ids are reused
        as valid ids and :meth:`train2valid` is called.
    val_size
        Stored on the instance; ``_load_raw_data`` passes it to
        ``train_test_split`` as ``test_size`` when no valid ids are set.

    """
    super().__init__(data_dir, full_download)

    # Which data to use.
    self.species = species
    self.tissue = tissue
    self.train_dataset = train_dataset
    self.test_dataset = test_dataset

    # Where the data lives.
    self.data_dir = data_dir
    self.train_dir = train_dir
    self.test_dir = test_dir
    self.valid_dir = valid_dir
    self.map_path = map_path

    # Per-instance copies, so that train2valid never touches the class-level tables.
    self.bench_url_dict = self.BENCH_URL_DICT.copy()
    self.available_data = self.AVAILABLE_DATA.copy()

    self.train_as_valid = train_as_valid
    reuse_train = valid_dataset is None and train_as_valid
    self.valid_dataset = train_dataset if reuse_train else valid_dataset
    if reuse_train:
        self.train2valid()
    self.val_size = val_size
77+
def train2valid(self):
    """Register every train entry a second time under the ``valid`` split.

    Entries of ``self.available_data`` whose ``"split"`` is ``"train"`` are
    copied with the split set to ``"valid"`` and appended. Keys of
    ``self.bench_url_dict`` that start with ``"train"`` get a twin key in
    which the first ``"train"`` is replaced by ``"valid"``, pointing at the
    same value. Both attributes are rebound to the extended copies.

    """
    logger.info("Copy train_dataset and use it as valid_dataset")

    valid_entries = []
    for entry in self.available_data:
        if entry["split"] != "train":
            continue
        mirrored = entry.copy()
        mirrored["split"] = "valid"
        valid_entries.append(mirrored)

    valid_urls = {
        name.replace("train", "valid", 1): url
        for name, url in self.bench_url_dict.items()
        if name.startswith("train")
    }

    extended_data = self.available_data.copy()
    extended_data.extend(valid_entries)
    extended_urls = self.bench_url_dict.copy()
    extended_urls.update(valid_urls)

    self.available_data = extended_data
    self.bench_url_dict = extended_urls
6892

6993
def download_all(self):
7094
if self.is_complete():
@@ -87,7 +111,8 @@ def download_all(self):
87111

88112
def get_all_filenames(self, filetype: str = "csv", feat_suffix: str = "data", label_suffix: str = "celltype"):
    """List the feature and label file names of all selected datasets.

    Ids are taken from ``self.train_dataset``, ``self.test_dataset`` and
    ``self.valid_dataset``, in that order. A split that is ``None`` is
    skipped, and each split may be any iterable of ids; the previous ``+``
    concatenation failed on ``None`` and on mixed sequence types.

    Parameters
    ----------
    filetype
        File extension, without the leading dot.
    feat_suffix
        Suffix of the feature file name.
    label_suffix
        Suffix of the label file name.

    Returns
    -------
    list of str
        For every id, ``{species}_{tissue}{id}_{feat_suffix}.{filetype}``
        followed by ``{species}_{tissue}{id}_{label_suffix}.{filetype}``.

    """
    filenames = []
    for id_group in (self.train_dataset, self.test_dataset, self.valid_dataset):
        if id_group is None:
            continue
        for dataset_id in id_group:
            filenames.append(f"{self.species}_{self.tissue}{dataset_id}_{feat_suffix}.{filetype}")
            filenames.append(f"{self.species}_{self.tissue}{dataset_id}_{label_suffix}.{filetype}")
    return filenames
@@ -98,7 +123,7 @@ def download(self, download_map=True):
98123

99124
filenames = self.get_all_filenames()
100125
# Download training and testing data
101-
for name, url in self.BENCH_URL_DICT.items():
126+
for name, url in self.bench_url_dict.items():
102127
parts = name.split("_") # [train|test]_{species}_{tissue}{id}_[celltype|data].csv
103128
filename = "_".join(parts[1:])
104129
if filename in filenames:
@@ -115,7 +140,6 @@ def is_complete_all(self):
115140
check = [
116141
osp.join(self.data_dir, "train"),
117142
osp.join(self.data_dir, "test"),
118-
osp.join(self.data_dir, "valid"),
119143
osp.join(self.data_dir, "pretrained")
120144
]
121145
for i in check:
@@ -126,7 +150,7 @@ def is_complete_all(self):
126150

127151
def is_complete(self):
128152
"""Check if benchmarking data is complete."""
129-
for name in self.BENCH_URL_DICT:
153+
for name in self.bench_url_dict:
130154
if any(i not in name for i in (self.species, self.tissue)):
131155
continue
132156
filename = name[name.find(self.species):]
@@ -150,58 +174,101 @@ def is_complete(self):
150174
def _load_raw_data(self, ct_col: str = "Cell_type") -> Tuple[ad.AnnData, List[Set[str]], List[str], int, int]:
    """Load the expression tables and cell type labels of all splits.

    The validation split comes from one of three places:

    * ``self.valid_dataset`` is set: it is loaded from ``self.valid_dir``;
    * otherwise, if ``self.val_size > 0``: it is split off the training data
      with ``train_test_split(..., test_size=self.val_size)``;
    * otherwise there is no validation split and its size is reported as 0.

    Parameters
    ----------
    ct_col
        Column of the label tables that holds the cell type name.

    Returns
    -------
    tuple
        ``(adata, labels, idx_to_label, train_size, valid_size)``. ``adata``
        stacks the train, (valid,) and test rows in that order, restricted to
        the training columns. ``labels`` follows the same row order; a
        non-train label missing from the training cell types is replaced by
        its entry in the cell type mapping, or ``None`` if it has none.
        ``idx_to_label`` is the sorted list of training cell types.

    """
    species, tissue, data_dir = self.species, self.tissue, self.data_dir
    map_path = osp.join(data_dir, self.map_path, species)

    def load_split(split_dir, dataset_ids):
        # Feature tables are loaded with transpose=True, label tables as-is.
        feat_paths, label_paths = self._get_data_paths(osp.join(data_dir, split_dir), species, tissue, dataset_ids)
        return self._load_dfs(feat_paths, transpose=True), self._load_dfs(label_paths)

    train_feat, train_label = load_split(self.train_dir, self.train_dataset)
    valid_feat = valid_label = None
    if self.valid_dataset is not None:
        valid_feat, valid_label = load_split(self.valid_dir, self.valid_dataset)
    test_feat, test_label = load_split(self.test_dir, self.test_dataset)

    if self.valid_dataset is None and self.val_size > 0:
        # NOTE(review): no random_state is passed, so the split is not
        # reproducible across runs unless the global NumPy RNG is seeded.
        train_feat, valid_feat, train_label, valid_label = train_test_split(train_feat, train_label,
                                                                            test_size=self.val_size)
    has_valid = valid_feat is not None

    # Combine features (only use features that are present in the training data)
    train_size = train_feat.shape[0]
    valid_size = valid_feat.shape[0] if has_valid else 0
    align_kwargs = dict(axis=1, join="left", fill_value=0)
    train_aligned, test_aligned = train_feat.align(test_feat, **align_kwargs)
    frames = [train_aligned]
    if has_valid:
        frames.append(train_feat.align(valid_feat, **align_kwargs)[1])
    frames.append(test_aligned)
    feat_df = pd.concat(frames).fillna(0)
    adata = ad.AnnData(feat_df, dtype=np.float32)

    # Convert cell type labels and map non-train cell type names to train
    cell_types = set(train_label[ct_col].unique())
    idx_to_label = sorted(cell_types)
    cell_type_mappings: Dict[str, Set[str]] = self.get_map_dict(map_path, tissue)

    def map_to_train(label_df):
        return [i if i in cell_types else cell_type_mappings.get(i) for i in label_df[ct_col]]

    def log_mapping(label_df, mapped):
        for i, j, k in zip(label_df.index, label_df[ct_col], mapped):
            logger.debug(f"{i}:{j}\t-> {k}")

    labels: List[Set[str]] = train_label[ct_col].tolist()
    if has_valid:
        valid_labels = map_to_train(valid_label)
        labels = labels + valid_labels
        logger.debug("Mapped valid cell-types:")
        log_mapping(valid_label, valid_labels)
    test_labels = map_to_train(test_label)
    labels = labels + test_labels
    logger.debug("Mapped test cell-types:")
    log_mapping(test_label, test_labels)

    logger.info(f"Loaded expression data: {adata}")
    logger.info(f"Number of training samples: {train_feat.shape[0]:,}")
    if has_valid:
        logger.info(f"Number of valid samples: {valid_feat.shape[0]:,}")
    logger.info(f"Number of testing samples: {test_feat.shape[0]:,}")
    logger.info(f"Cell-types (n={len(idx_to_label)}):\n{pprint.pformat(idx_to_label)}")

    return adata, labels, idx_to_label, train_size, valid_size
205272

206273
def _raw_to_dance(self, raw_data):
207274
adata, cell_labels, idx_to_label, train_size, valid_size = raw_data
@@ -290,9 +357,10 @@ def is_complete(self):
290357
return osp.exists(self.data_path)
291358

292359
def _load_raw_data(self) -> Tuple[ad.AnnData, np.ndarray]:
    """Read the ``X`` and ``Y`` datasets from the HDF5 file at ``self.data_path``.

    The file is opened by Python in binary mode and the open handle is given
    to ``h5py.File``; both are closed before the AnnData object is built.

    Returns
    -------
    tuple
        ``(adata, y)`` where ``adata`` wraps ``X`` as float32 and ``y`` is
        the ``Y`` dataset as a NumPy array.

    """
    with open(self.data_path, "rb") as raw_file, h5py.File(raw_file, "r") as h5_file:
        features = np.array(h5_file["X"])
        targets = np.array(h5_file["Y"])
    return ad.AnnData(features, dtype=np.float32), targets
298366

dance/datasets/spatial.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class SpatialLIBDDataset(BaseDataset):
2424

2525
_DISPLAY_ATTRS = ("data_id", )
2626
URL_DICT = {
27-
"151510": "https://www.dropbox.com/sh/41h9brsk6my546x/AADa18mkJge-KQRTndRelTpMa?dl=0",
27+
"151510": "https://www.dropbox.com/sh/41h9brsk6my546x/AADa18mkJge-KQRTndRelTpMa?dl=1",
2828
"151507": "https://www.dropbox.com/sh/m3554vfrdzbwv2c/AACGsFNVKx8rjBgvF7Pcm2L7a?dl=1",
2929
"151508": "https://www.dropbox.com/sh/tm47u3fre8692zt/AAAJJf8-za_Lpw614ft096qqa?dl=1",
3030
"151509": "https://www.dropbox.com/sh/hihr7906vyirjet/AACslV5mKIkF2CF5QqE1LE6ya?dl=1",
@@ -47,11 +47,12 @@ class SpatialLIBDDataset(BaseDataset):
4747
}
4848
AVAILABLE_DATA = sorted(URL_DICT)
4949

50-
def __init__(self, root=".", full_download=False, data_id="151673", data_dir="data/spatial", sample_file=None):
    """Select one dataset by id and derive its data directory.

    Parameters
    ----------
    root, full_download
        Forwarded to the parent constructor.
    data_id
        Id of the dataset to use.
    data_dir
        Base directory; the instance attribute becomes ``{data_dir}/{data_id}``.
    sample_file
        Optional file name inside the data directory. ``_raw_to_dance`` reads
        one integer index per line from it and keeps only those rows.

    """
    super().__init__(root, full_download)

    self.sample_file = sample_file
    self.data_id = data_id
    self.data_dir = data_dir + f"/{data_id}"
5556

5657
def download_all(self):
5758
logger.info(f"All data includes {len(self.URL_DICT)} datasets: {list(self.URL_DICT)}")
@@ -147,7 +148,11 @@ def _raw_to_dance(self, raw_data):
147148
adata.obsm["spatial"] = xy.set_index(adata.obs_names)
148149
adata.obsm["spatial_pixel"] = xy_pixel.set_index(adata.obs_names)
149150
adata.uns["image"] = img
150-
151+
if self.sample_file is not None:
152+
sample_file = osp.join(self.data_dir, self.sample_file)
153+
with open(sample_file) as file:
154+
sample_index = [int(line.strip()) for line in file]
155+
adata = adata[sample_index]
151156
data = Data(adata, train_size="all")
152157
return data
153158

dance/metadata/clustering.csv

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
10X_PBMC,https://www.dropbox.com/s/pfunm27qzgfpj3u/10X_PBMC.h5?dl=1
22
mouse_lung_cell,https://dl.dropboxusercontent.com/scl/fi/6h4ewvj1n64mrppz56v7s/mouse_lung_cell.h5?rlkey=6snhzzkv6f7vmshkvne9leimu&dl=0
33
human_pbmc2_cell,https://dl.dropboxusercontent.com/scl/fi/c69gv5btxfvpcmkcqj4zx/human_pbmc_cell.h5?rlkey=jzi6u9qs2qf4nixr6a0n48mc0&dl=0
4-
human_pbmc_cell,https://dl.dropboxusercontent.com/scl/fi/2by36reg6wjq6hxlytljx/human_ILCS_cell.h5?rlkey=mu4pz7quxspf9qgzx5wc4aaet&dl=0
4+
human_ILCS_cell,https://dl.dropboxusercontent.com/scl/fi/2by36reg6wjq6hxlytljx/human_ILCS_cell.h5?rlkey=mu4pz7quxspf9qgzx5wc4aaet&dl=0
55
human_skin_cell,https://dl.dropboxusercontent.com/scl/fi/5gd3kcz307r42s7u3di3q/human_skin_cell.h5?rlkey=2hat0jeze2cn2uqnu4p7g7yhw&dl=0
66
mouse_ES_cell,https://www.dropbox.com/s/zbuku7oznvji8jk/mouse_ES_cell.h5?dl=1
77
mouse_bladder_cell,https://www.dropbox.com/s/xxtnomx5zrifdwi/mouse_bladder_cell.h5?dl=1
8-
mouse_kidney_10x,https://dl.dropboxusercontent.com/scl/fi/b9b4dr82hcdwxykv8e53f/mouse_kidney_10x.h5?rlkey=aniqqz731klpmekl82db7k2pu&dl=
9-
mouse_kidney_cell,https://dl.dropboxusercontent.com/scl/fi/b9b4dr82hcdwxykv8e53f/mouse_kidney_10x.h5?rlkey=aniqqz731klpmekl82db7k2pu&dl=0
10-
mouse_kidney_cl2,https://dl.dropboxusercontent.com/scl/fi/d0uh8qqw4q4f0748yq5db/mouse_kidney_drop.h5?rlkey=3onfglh6sv6q91c5e1ns5lc5h&dl=0
8+
mouse_kidney_10x,https://dl.dropboxusercontent.com/scl/fi/b9b4dr82hcdwxykv8e53f/mouse_kidney_10x.h5?rlkey=aniqqz731klpmekl82db7k2pu&dl=1
9+
mouse_kidney_cell,https://www.dropbox.com/scl/fi/qrkyu9qhfcj43smlfqygq/mouse_kidney_cell.h5?rlkey=a0uyhgxfty4iti0k83xx9gtsc&dl=1
10+
mouse_kidney_cl2,https://www.dropbox.com/scl/fi/g60cr1t6dqvtv5zei4h3m/mouse_kidney_cl2.h5?rlkey=gth7bakq4tugztiv1r1akgy8l&dl=1
1111
mouse_kidney_drop,https://dl.dropboxusercontent.com/scl/fi/d0uh8qqw4q4f0748yq5db/mouse_kidney_drop.h5?rlkey=3onfglh6sv6q91c5e1ns5lc5h&dl=0
1212
worm_neuron_cell,https://www.dropbox.com/s/58fkgemi2gcnp2k/worm_neuron_cell.h5?dl=1

0 commit comments

Comments
 (0)