Skip to content

Commit 8037b1c

Browse files
committed
update ans
1 parent 2baf956 commit 8037b1c

File tree

5 files changed, with 57 additions and 373 deletions.

dance/atlas/sc_similarity/anndata_similarity.py

Lines changed: 43 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,15 +18,32 @@
1818
from scipy.spatial.distance import cdist, directed_hausdorff, jaccard, jensenshannon
1919
from sklearn.metrics.pairwise import cosine_similarity, rbf_kernel
2020

21+
from dance.settings import METADIR
22+
2123
# Suppress scipy warnings for constant input in Pearson correlation
2224
warnings.filterwarnings("ignore", message="An input array is constant")
2325
from dance.datasets.singlemodality import CellTypeAnnotationDataset
2426

2527

2628
def get_anndata(tissue: str = "Blood", species: str = "human", filetype: str = "h5ad", train_dataset=[],
2729
test_dataset=[], valid_dataset=[], data_dir="../temp_data"):
28-
if train_dataset == ['84230ea4-998d-4aa8-8456-81dd54ce23af']:
29-
pass
30+
31+
def find_dataset_in_metadata(datasets, tissue):
32+
datasets_in_metadata = []
33+
for dataset_id in datasets:
34+
all_datasets = pd.read_csv(METADIR / "scdeepsort.csv", header=0, skiprows=[i for i in range(1, 68)])
35+
for collect_dataset in all_datasets[all_datasets["tissue"] == tissue]["data_fname"].tolist():
36+
if dataset_id in collect_dataset:
37+
datasets_in_metadata.append(
38+
(collect_dataset.split(tissue)[1] +
39+
(tissue + collect_dataset.split(tissue)[2] if len(collect_dataset.split(tissue)) >= 3 else '')
40+
).split('_')[0])
41+
break
42+
return datasets_in_metadata
43+
44+
train_dataset = find_dataset_in_metadata(train_dataset, tissue)
45+
valid_dataset = find_dataset_in_metadata(valid_dataset, tissue)
46+
test_dataset = find_dataset_in_metadata(test_dataset, tissue)
3047
data = CellTypeAnnotationDataset(train_dataset=train_dataset, test_dataset=test_dataset,
3148
valid_dataset=valid_dataset, data_dir=data_dir, tissue=tissue, species=species,
3249
filetype=filetype).load_data()
@@ -89,6 +106,8 @@ def filter_gene(self, n_top_genes=3000):
89106
Number of top variable genes to select
90107
91108
"""
109+
sc.pp.filter_genes(self.origin_adata1, min_counts=3)
110+
sc.pp.filter_genes(self.origin_adata2, min_counts=3)
92111
sc.pp.highly_variable_genes(self.origin_adata1, n_top_genes=n_top_genes, flavor='seurat_v3')
93112
sc.pp.highly_variable_genes(self.origin_adata2, n_top_genes=n_top_genes, flavor='seurat_v3')
94113

@@ -195,6 +214,20 @@ def jsd(p, q):
195214
similarity_matrix = 1 - divergence_matrix
196215
return np.nanmean(similarity_matrix)
197216

217+
def compute_mmd_alternative(self) -> float:
218+
X = self.X
219+
Y = self.Y
220+
gamma = 1.0
221+
K_XX = rbf_kernel(X, X, gamma)
222+
K_YY = rbf_kernel(Y, Y, gamma)
223+
K_XY = rbf_kernel(X, Y, gamma)
224+
n_x = X.shape[0]
225+
n_y = Y.shape[0]
226+
mmd = (K_XX.sum() - np.trace(K_XX)) / (n_x * (n_x - 1)) \
227+
+ (K_YY.sum() - np.trace(K_YY)) / (n_y * (n_y - 1)) \
228+
- 2 * K_XY.mean()
229+
return 1 / (1 + np.sqrt(max(mmd, 0)))
230+
198231
def compute_mmd(self) -> float:
199232
"""Compute Maximum Mean Discrepancy between datasets.
200233
@@ -359,6 +392,12 @@ def get_dataset_info(data: ad.AnnData):
359392
con_sim["n_measured_vars"] = np.mean(data.obs["n_measured_vars"])
360393
con_sim["cell_num"] = len(data.obs)
361394
con_sim["gene_num"] = len(data.var)
395+
if "n_counts" not in data.obs.columns:
396+
if scipy.sparse.issparse(data.X):
397+
cell_counts = np.array(data.X.sum(axis=1)).flatten()
398+
else:
399+
cell_counts = data.X.sum(axis=1)
400+
data.obs["n_counts"] = cell_counts
362401
con_sim["n_counts_mean"] = np.mean(data.obs["n_counts"])
363402
con_sim["n_counts_var"] = np.var(data.obs["n_counts"])
364403
# if "n_counts" not in data.var.columns:
@@ -404,9 +443,9 @@ def get_targets(dataset_truth: str):
404443
sim_targets = []
405444
for method in self.methods:
406445
query_dataset_truth = ground_truth_conf.loc[ground_truth_conf["dataset_id"] == self.adata1_name,
407-
f"{method}_best_yaml"].iloc[0]
446+
f"{method}_step2_best_yaml"].iloc[0]
408447
atlas_dataset_truth = ground_truth_conf.loc[ground_truth_conf["dataset_id"] == self.adata2_name,
409-
f"{method}_best_yaml"].iloc[0]
448+
f"{method}_step2_best_yaml"].iloc[0]
410449
if type(atlas_dataset_truth) == float and np.isnan(atlas_dataset_truth):
411450
return 0
412451
query_targets = get_targets(query_dataset_truth)

dance/settings.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,10 @@ def change_log_level(name: str = "dance", /, *, level: Union[str, int]):
4747
DANCEDIR = Path(__file__).resolve().parents[1]
4848
DANCEPKGDIR = DANCEDIR / "dance"
4949
METADIR = DANCEPKGDIR / "metadata"
50-
50+
ATLASDIR = DANCEDIR / "examples/atlas"
51+
SIMILARITYDIR = ATLASDIR / "sc_similarity_examples"
52+
entity = "xzy11632"
53+
project = "dance-dev"
5154
__all__ = [
5255
"change_log_level",
5356
]

examples/atlas/get_result_web.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tqdm import tqdm
1313

1414
from dance import logger
15-
from dance.settings import DANCEDIR, METADIR
15+
from dance.settings import ATLASDIR, DANCEDIR, METADIR
1616
from dance.utils import try_import
1717

1818
# get yaml of best method

0 commit comments

Comments
 (0)