Skip to content

Commit 50141ae

Browse files
committed
update user interface
1 parent 315b532 commit 50141ae

File tree

2 files changed

+167
-0
lines changed

2 files changed

+167
-0
lines changed

examples/atlas/demos/main.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# This script finds the most similar dataset from the atlas for a given user-uploaded dataset
2+
# It calculates similarity scores and returns the best matching dataset along with its configurations
3+
4+
import argparse
5+
import json
6+
7+
import pandas as pd
8+
import scanpy as sc
9+
10+
from dance import logger
11+
from dance.atlas.sc_similarity.anndata_similarity import AnnDataSimilarity, get_anndata
12+
from dance.settings import DANCEDIR, SIMILARITYDIR
13+
14+
15+
def calculate_similarity(source_data, tissue, atlas_datasets, reduce_error, in_query):
    """Score every candidate atlas dataset by similarity to the uploaded data.

    Args:
        source_data: User uploaded AnnData object.
        tissue: Target tissue type (lower-case).
        atlas_datasets: Candidate dataset ids drawn from the atlas.
        reduce_error: When True, load the weight file that heavily penalizes
            atlas configurations that produced errors.
        in_query: When True, load weights ranked on query performance; when
            False, weights ranked on inter-atlas comparison.

    Returns:
        Dict mapping each atlas dataset id to its weighted similarity score.
    """
    # The weight-file name encodes which ranking mode produced it.
    prefix = f"{'reduce_error_' if reduce_error else ''}{'in_query_' if in_query else ''}"
    weights_path = SIMILARITYDIR / f"data/similarity_weights_results/{prefix}sim_dict.json"
    with open(weights_path, encoding='utf-8') as f:
        sim_dict = json.load(f)

    tissue_cfg = sim_dict[tissue]
    feature_name = tissue_cfg["feature_name"]
    w1 = tissue_cfg["weight1"]
    w2 = 1 - w1

    scores = {}
    for target_file in atlas_datasets:
        logger.info(f"calculating similarity for {target_file}")
        atlas_data = get_anndata(tissue=tissue.capitalize(), species="human", filetype="h5ad",
                                 train_dataset=[f"{target_file}"],
                                 data_dir=str(DANCEDIR / "examples/tuning/temp_data"))
        calculator = AnnDataSimilarity(adata1=source_data, adata2=atlas_data, sample_size=10,
                                       init_random_state=42, n_runs=1, tissue=tissue)
        sim_target = calculator.get_similarity_matrix_A2B(methods=[feature_name, "metadata_sim"])
        # Final score is a weighted blend of the feature-based and metadata similarities.
        scores[target_file] = sim_target[feature_name] * w1 + sim_target["metadata_sim"] * w2
    return scores
49+
50+
51+
def main(args):
    """Find the atlas dataset most similar to the user-supplied dataset.

    Args:
        args: Parsed arguments providing:
            - tissue: Target tissue type.
            - data_dir: Directory containing the source data.
            - source_file: Name of the source h5ad file (without extension).

    Returns:
        Tuple containing:
            - ans_file: Id of the most similar atlas dataset.
            - ans_conf: Best preprocess yaml per cell type annotation method.
            - ans_value: Similarity score of the best matching dataset.
    """
    reduce_error = False
    in_query = True
    tissue = args.tissue.lower()

    conf_data = pd.read_excel(SIMILARITYDIR / "data/Cell Type Annotation Atlas.xlsx", sheet_name=tissue)
    # Only datasets not already used as queries are valid candidates.
    atlas_datasets = list(conf_data[conf_data["queryed"] == False]["dataset_id"])
    source_data = sc.read_h5ad(f"{args.data_dir}/{args.source_file}.h5ad")

    scores = calculate_similarity(source_data, tissue, atlas_datasets, reduce_error, in_query)
    ans_file = max(scores, key=scores.get)
    ans_value = scores[ans_file]
    # Pull the best step2 yaml config of the winning dataset for each CTA method.
    methods = ["cta_celltypist", "cta_scdeepsort", "cta_singlecellnet", "cta_actinn"]
    ans_conf = {
        method: conf_data.loc[conf_data["dataset_id"] == ans_file, f"{method}_step2_best_yaml"].iloc[0]
        for method in methods
    }
    return ans_file, ans_conf, ans_value
83+
84+
85+
if __name__ == "__main__":
    # CLI entry point: locate the best-matching atlas dataset for the given file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--tissue", default="Brain")
    parser.add_argument("--data_dir", default=str(DANCEDIR / "examples/tuning/temp_data/train/human"))
    parser.add_argument("--source_file", default="human_Brain364348b4-bc34-4fe1-a851-60d99e36cafa_data")

    ans_file, ans_conf, ans_value = main(parser.parse_args())
    print(ans_file, ans_conf, ans_value)

tests/atlas/test_atlas.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
2+
Test suite for the Atlas similarity calculation functionality.
3+
This test verifies that the main function correctly returns:
4+
1. The most similar dataset from the atlas
5+
2. Its corresponding configuration settings
6+
3. The similarity score
7+
8+
The test ensures:
9+
- Return value types are correct
10+
- Similarity score is within valid range (0-1)
11+
- Configuration dictionary contains all required cell type annotation methods
12+
"""
13+
14+
import json
15+
import sys
16+
17+
import pandas as pd
18+
19+
from dance.settings import ATLASDIR, DANCEDIR, SIMILARITYDIR
20+
21+
sys.path.append(str(ATLASDIR))
22+
from demos.main import main
23+
24+
from dance import logger
25+
26+
27+
def test_main():
    """Verify main() matches the reference results stored in the Excel sheet."""

    # Sample Brain tissue query dataset used as the test fixture.
    class Args:
        tissue = "Brain"
        data_dir = str(DANCEDIR / "examples/tuning/temp_data/train/human")
        source_file = "human_Brain364348b4-bc34-4fe1-a851-60d99e36cafa_data"

    args = Args()
    logger.info(f"testing main with args: {args}")
    source_id = "3643"

    # Run the function under test.
    ans_file, ans_conf, ans_value = main(args)

    # Types and value range of the returned triple.
    assert isinstance(ans_file, str), "ans_file should be a string type"
    assert isinstance(ans_value, float), "ans_value should be a float type"
    assert 0 <= ans_value <= 1, "Similarity value should be between 0 and 1"

    # Configuration dictionary structure and content.
    expected_methods = ["cta_celltypist", "cta_scdeepsort", "cta_singlecellnet", "cta_actinn"]
    assert isinstance(ans_conf, dict), "ans_conf should be a dictionary type"
    assert set(ans_conf.keys()) == set(expected_methods), "ans_conf should contain all expected methods"
    assert all(isinstance(v, str) for v in ans_conf.values()), "All configuration values should be string type"

    # Recompute the expected answer independently from the reference spreadsheet.
    tissue = args.tissue.lower()
    data = pd.read_excel(SIMILARITYDIR / f"data/new_sim/{tissue}_similarity.xlsx", sheet_name=source_id,
                         index_col=0)
    reduce_error = False
    in_query = True
    # Load the same weight file main() uses.
    with open(
            SIMILARITYDIR /
            f"data/similarity_weights_results/{'reduce_error_' if reduce_error else ''}{'in_query_' if in_query else ''}sim_dict.json",
            encoding='utf-8') as f:
        sim_dict = json.load(f)
    feature_name = sim_dict[tissue]["feature_name"]
    w1 = sim_dict[tissue]["weight1"]
    w2 = 1 - w1

    # Blend the two similarity rows exactly as main() does, then take the argmax.
    data.loc["similarity"] = data.loc[feature_name] * w1 + data.loc["metadata_sim"] * w2
    expected_file = data.loc["similarity"].idxmax()
    expected_value = data.loc["similarity", expected_file]

    # The computed answer must agree with the spreadsheet-derived answer.
    assert abs(ans_value - expected_value) < 1e-4, "Calculated similarity value does not match Excel value"
    assert ans_file == expected_file, "Selected most similar dataset does not match Excel result"

0 commit comments

Comments
 (0)