|
18 | 18 | from scipy.spatial.distance import cdist, directed_hausdorff, jaccard, jensenshannon |
19 | 19 | from sklearn.metrics.pairwise import cosine_similarity, rbf_kernel |
20 | 20 |
|
| 21 | +from dance.settings import METADIR |
| 22 | + |
21 | 23 | # Suppress scipy warnings for constant input in Pearson correlation |
22 | 24 | warnings.filterwarnings("ignore", message="An input array is constant") |
23 | 25 | from dance.datasets.singlemodality import CellTypeAnnotationDataset |
24 | 26 |
|
25 | 27 |
|
26 | 28 | def get_anndata(tissue: str = "Blood", species: str = "human", filetype: str = "h5ad", train_dataset=[], |
27 | 29 | test_dataset=[], valid_dataset=[], data_dir="../temp_data"): |
28 | | - if train_dataset == ['84230ea4-998d-4aa8-8456-81dd54ce23af']: |
29 | | - pass |
| 30 | + |
def find_dataset_in_metadata(datasets, tissue):
    """Resolve dataset ids to the identifiers stored in ``scdeepsort.csv``.

    For every id in ``datasets`` the first ``data_fname`` entry of the
    requested ``tissue`` that contains the id is located. The part of that
    file name following the tissue name, up to the first underscore, is
    collected. Ids without a matching entry are left out, so the result can
    be shorter than the input.

    Parameters
    ----------
    datasets
        Iterable of dataset id strings to look up.
    tissue
        Tissue name; used both to filter the ``tissue`` column and to split
        the matching ``data_fname``.

    Returns
    -------
    list
        Extracted identifiers, in the order of the ids that were found.

    Raises
    ------
    IndexError
        If a matching ``data_fname`` does not contain ``tissue`` as a
        substring.

    """
    dataset_ids = list(datasets)
    if not dataset_ids:
        # Nothing to resolve, so the metadata file is not read at all.
        return []

    # The table is independent of the id being looked up: load and filter once
    # instead of once per id. ``skiprows`` drops file lines 1-67 while keeping
    # line 0 as the header.
    # NOTE(review): presumably those rows are entries this lookup must ignore;
    # confirm against the layout of scdeepsort.csv.
    all_datasets = pd.read_csv(METADIR / "scdeepsort.csv", header=0, skiprows=list(range(1, 68)))
    tissue_fnames = all_datasets[all_datasets["tissue"] == tissue]["data_fname"].tolist()

    datasets_in_metadata = []
    for dataset_id in dataset_ids:
        for collect_dataset in tissue_fnames:
            if dataset_id not in collect_dataset:
                continue
            parts = collect_dataset.split(tissue)
            # Text after the first occurrence of the tissue name; when the
            # name occurs a second time the following piece is stitched back on.
            suffix = parts[1] + (tissue + parts[2] if len(parts) >= 3 else "")
            datasets_in_metadata.append(suffix.split("_")[0])
            break  # only the first matching file name is used
    return datasets_in_metadata
| 43 | + |
| 44 | + train_dataset = find_dataset_in_metadata(train_dataset, tissue) |
| 45 | + valid_dataset = find_dataset_in_metadata(valid_dataset, tissue) |
| 46 | + test_dataset = find_dataset_in_metadata(test_dataset, tissue) |
30 | 47 | data = CellTypeAnnotationDataset(train_dataset=train_dataset, test_dataset=test_dataset, |
31 | 48 | valid_dataset=valid_dataset, data_dir=data_dir, tissue=tissue, species=species, |
32 | 49 | filetype=filetype).load_data() |
@@ -89,6 +106,8 @@ def filter_gene(self, n_top_genes=3000): |
89 | 106 | Number of top variable genes to select |
90 | 107 |
|
91 | 108 | """ |
| 109 | + sc.pp.filter_genes(self.origin_adata1, min_counts=3) |
| 110 | + sc.pp.filter_genes(self.origin_adata2, min_counts=3) |
92 | 111 | sc.pp.highly_variable_genes(self.origin_adata1, n_top_genes=n_top_genes, flavor='seurat_v3') |
93 | 112 | sc.pp.highly_variable_genes(self.origin_adata2, n_top_genes=n_top_genes, flavor='seurat_v3') |
94 | 113 |
|
@@ -195,6 +214,20 @@ def jsd(p, q): |
195 | 214 | similarity_matrix = 1 - divergence_matrix |
196 | 215 | return np.nanmean(similarity_matrix) |
197 | 216 |
|
def compute_mmd_alternative(self, gamma: float = 1.0) -> float:
    """Compute an MMD-based similarity between ``self.X`` and ``self.Y``.

    The squared Maximum Mean Discrepancy is estimated with an RBF kernel.
    The two within-set terms exclude the kernel diagonal and are normalised
    by ``n * (n - 1)``; the cross term is the mean of the cross kernel. The
    estimate is clipped at zero and mapped through ``1 / (1 + sqrt(mmd))``,
    so a non-positive estimate yields ``1.0`` and larger discrepancies yield
    smaller scores.

    Parameters
    ----------
    gamma
        Kernel coefficient forwarded to ``rbf_kernel``. Defaults to ``1.0``.

    Returns
    -------
    float
        Similarity score. ``nan`` is returned when either set has fewer than
        two rows, because the ``n * (n - 1)`` normaliser is then zero.

    """
    X = self.X
    Y = self.Y
    n_x = X.shape[0]
    n_y = Y.shape[0]
    if n_x < 2 or n_y < 2:
        # The diagonal-free estimator is undefined here (0/0); report that
        # explicitly instead of relying on a floating-point warning path.
        return float("nan")

    K_XX = rbf_kernel(X, X, gamma=gamma)
    K_YY = rbf_kernel(Y, Y, gamma=gamma)
    K_XY = rbf_kernel(X, Y, gamma=gamma)

    # Off-diagonal sums only: the diagonal holds each sample's similarity to itself.
    within_x = (K_XX.sum() - np.trace(K_XX)) / (n_x * (n_x - 1))
    within_y = (K_YY.sum() - np.trace(K_YY)) / (n_y * (n_y - 1))
    mmd = within_x + within_y - 2 * K_XY.mean()

    # The estimate can dip below zero, so clip before taking the square root.
    return 1 / (1 + np.sqrt(max(mmd, 0)))
| 230 | + |
198 | 231 | def compute_mmd(self) -> float: |
199 | 232 | """Compute Maximum Mean Discrepancy between datasets. |
200 | 233 |
|
@@ -359,6 +392,12 @@ def get_dataset_info(data: ad.AnnData): |
359 | 392 | con_sim["n_measured_vars"] = np.mean(data.obs["n_measured_vars"]) |
360 | 393 | con_sim["cell_num"] = len(data.obs) |
361 | 394 | con_sim["gene_num"] = len(data.var) |
| 395 | + if "n_counts" not in data.obs.columns: |
| 396 | + if scipy.sparse.issparse(data.X): |
| 397 | + cell_counts = np.array(data.X.sum(axis=1)).flatten() |
| 398 | + else: |
| 399 | + cell_counts = data.X.sum(axis=1) |
| 400 | + data.obs["n_counts"] = cell_counts |
362 | 401 | con_sim["n_counts_mean"] = np.mean(data.obs["n_counts"]) |
363 | 402 | con_sim["n_counts_var"] = np.var(data.obs["n_counts"]) |
364 | 403 | # if "n_counts" not in data.var.columns: |
@@ -404,9 +443,9 @@ def get_targets(dataset_truth: str): |
404 | 443 | sim_targets = [] |
405 | 444 | for method in self.methods: |
406 | 445 | query_dataset_truth = ground_truth_conf.loc[ground_truth_conf["dataset_id"] == self.adata1_name, |
407 | | - f"{method}_best_yaml"].iloc[0] |
| 446 | + f"{method}_step2_best_yaml"].iloc[0] |
408 | 447 | atlas_dataset_truth = ground_truth_conf.loc[ground_truth_conf["dataset_id"] == self.adata2_name, |
409 | | - f"{method}_best_yaml"].iloc[0] |
| 448 | + f"{method}_step2_best_yaml"].iloc[0] |
410 | 449 | if type(atlas_dataset_truth) == float and np.isnan(atlas_dataset_truth): |
411 | 450 | return 0 |
412 | 451 | query_targets = get_targets(query_dataset_truth) |
|
0 commit comments