|
| 1 | +# This script finds the most similar dataset from the atlas for a given user-uploaded dataset |
| 2 | +# It calculates similarity scores and returns the best matching dataset along with its configurations |
| 3 | + |
| 4 | +import argparse |
| 5 | +import json |
| 6 | + |
| 7 | +import pandas as pd |
| 8 | +import scanpy as sc |
| 9 | + |
| 10 | +from dance import logger |
| 11 | +from dance.atlas.sc_similarity.anndata_similarity import AnnDataSimilarity, get_anndata |
| 12 | +from dance.settings import DANCEDIR, SIMILARITYDIR |
| 13 | + |
| 14 | + |
| 15 | +def calculate_similarity(source_data, tissue, atlas_datasets, reduce_error, in_query): |
| 16 | + """Calculate similarity scores between source data and atlas datasets. |
| 17 | +
|
| 18 | + Args: |
| 19 | + source_data: User uploaded AnnData object |
| 20 | + tissue: Target tissue type |
| 21 | + atlas_datasets: List of candidate datasets from atlas |
| 22 | + reduce_error: Flag for error reduction mode - when True, applies a significant penalty |
| 23 | + to configurations in the atlas that produced errors |
| 24 | + in_query: Flag for query mode - when True, ranks similarity based on query performance, |
| 25 | + when False, ranks based on inter-atlas comparison |
| 26 | +
|
| 27 | + Returns: |
| 28 | + Dictionary containing similarity scores for each atlas dataset |
| 29 | +
|
| 30 | + """ |
| 31 | + with open( |
| 32 | + SIMILARITYDIR / |
| 33 | + f"data/similarity_weights_results/{'reduce_error_' if reduce_error else ''}{'in_query_' if in_query else ''}sim_dict.json", |
| 34 | + encoding='utf-8') as f: |
| 35 | + sim_dict = json.load(f) |
| 36 | + feature_name = sim_dict[tissue]["feature_name"] |
| 37 | + w1 = sim_dict[tissue]["weight1"] |
| 38 | + w2 = 1 - w1 |
| 39 | + ans = {} |
| 40 | + for target_file in atlas_datasets: |
| 41 | + logger.info(f"calculating similarity for {target_file}") |
| 42 | + atlas_data = get_anndata(tissue=tissue.capitalize(), species="human", filetype="h5ad", |
| 43 | + train_dataset=[f"{target_file}"], data_dir=str(DANCEDIR / "examples/tuning/temp_data")) |
| 44 | + similarity_calculator = AnnDataSimilarity(adata1=source_data, adata2=atlas_data, sample_size=10, |
| 45 | + init_random_state=42, n_runs=1, tissue=tissue) |
| 46 | + sim_target = similarity_calculator.get_similarity_matrix_A2B(methods=[feature_name, "metadata_sim"]) |
| 47 | + ans[target_file] = sim_target[feature_name] * w1 + sim_target["metadata_sim"] * w2 |
| 48 | + return ans |
| 49 | + |
| 50 | + |
| 51 | +def main(args): |
| 52 | + """Main function to process user data and find the most similar atlas dataset. |
| 53 | +
|
| 54 | + Args: |
| 55 | + args: Arguments containing: |
| 56 | + - tissue: Target tissue type |
| 57 | + - data_dir: Directory containing the source data |
| 58 | + - source_file: Name of the source file |
| 59 | +
|
| 60 | + Returns: |
| 61 | + tuple containing: |
| 62 | + - ans_file: ID of the most similar dataset |
| 63 | + - ans_conf: Preprocess configuration dictionary for different cell type annotation methods |
| 64 | + - ans_value: Similarity score of the best matching dataset |
| 65 | +
|
| 66 | + """ |
| 67 | + reduce_error = False |
| 68 | + in_query = True |
| 69 | + tissue = args.tissue |
| 70 | + tissue = tissue.lower() |
| 71 | + conf_data = pd.read_excel(SIMILARITYDIR / "data/Cell Type Annotation Atlas.xlsx", sheet_name=tissue) |
| 72 | + atlas_datasets = list(conf_data[conf_data["queryed"] == False]["dataset_id"]) |
| 73 | + source_data = sc.read_h5ad(f"{args.data_dir}/{args.source_file}.h5ad") |
| 74 | + |
| 75 | + ans = calculate_similarity(source_data, tissue, atlas_datasets, reduce_error, in_query) |
| 76 | + ans_file = max(ans, key=ans.get) |
| 77 | + ans_value = ans[ans_file] |
| 78 | + ans_conf = { |
| 79 | + method: conf_data.loc[conf_data["dataset_id"] == ans_file, f"{method}_step2_best_yaml"].iloc[0] |
| 80 | + for method in ["cta_celltypist", "cta_scdeepsort", "cta_singlecellnet", "cta_actinn"] |
| 81 | + } |
| 82 | + return ans_file, ans_conf, ans_value |
| 83 | + |
| 84 | + |
| 85 | +if __name__ == "__main__": |
| 86 | + parser = argparse.ArgumentParser() |
| 87 | + parser.add_argument("--tissue", default="Brain") |
| 88 | + parser.add_argument("--data_dir", default=str(DANCEDIR / "examples/tuning/temp_data/train/human")) |
| 89 | + parser.add_argument("--source_file", default="human_Brain364348b4-bc34-4fe1-a851-60d99e36cafa_data") |
| 90 | + |
| 91 | + args = parser.parse_args() |
| 92 | + ans_file, ans_conf, ans_value = main(args) |
| 93 | + print(ans_file, ans_conf, ans_value) |
0 commit comments