|
# Component parameters: input/output paths, which scPRINT variant to use,
# and an optional pre-downloaded checkpoint.
par = dict(
    input="resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
    output="output.h5ad",
    model_name="v2-medium",  # scPRINT model variant fetched from the hub
    model=None,  # optional local checkpoint; None means download by name
)
# Component metadata.
meta = dict(name="scprint")
|
30 | 30 |
|
print("\n>>> Reading input data...", flush=True)
# NOTE: `input` shadows the builtin; kept as-is because later code in this
# script references it by this name.
input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")

# Fill in the organism ontology term from `dataset_organism`, but only when
# the column is not already present (pre-annotated data is left untouched).
if "organism_ontology_term_id" not in input.obs.columns and "dataset_organism" in input.uns:
    taxon_by_organism = {
        "homo_sapiens": "NCBITaxon:9606",
        "mus_musculus": "NCBITaxon:10090",
    }
    organism = input.uns["dataset_organism"]
    if organism in taxon_by_organism:
        input.obs["organism_ontology_term_id"] = taxon_by_organism[organism]
    else:
        # scPRINT only supports human and mouse data.
        exit_non_applicable(
            f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
        )
adata = input.copy()
42 | 46 |
|
43 | 47 | print("\n>>> Preprocessing data...", flush=True) |
|
59 | 63 | repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt" |
60 | 64 | ) |
61 | 65 |
|
62 | | -print("\n>>> Embedding data...", flush=True) |
# Pick precision, tensor dtype, and attention implementation from hardware:
# half precision + flash attention on GPU, full precision otherwise.
if torch.cuda.is_available():
    print("CUDA is available, using GPU", flush=True)
    precision, dtype, transformer = "16", torch.float16, "flash"
else:
    # Flash attention requires CUDA, so fall back to the normal transformer.
    print("CUDA is not available, using CPU", flush=True)
    precision, dtype, transformer = "32", torch.float32, "normal"
73 | 76 |
|
print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)

# Inspect the checkpoint's stored hyperparameters so both checkpoint
# generations load through a single call: newer checkpoints carry
# "label_counts" and need it passed as `classes`, older ones do not.
# NOTE(review): torch.load unpickles arbitrary objects; this trusts the
# Hugging Face hub checkpoint it was downloaded from.
ckpt = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
extra_kwargs = {}
if "label_counts" in ckpt["hyper_parameters"]:
    extra_kwargs["classes"] = ckpt["hyper_parameters"]["label_counts"]
model = scPrint.load_from_checkpoint(
    model_checkpoint_file,
    transformer=transformer,  # Don't use this for GPUs with flashattention
    precpt_gene_emb=None,
    **extra_kwargs,
)
del ckpt  # release the raw checkpoint before embedding
| 94 | + |
print("\n>>> Embedding data...", flush=True)
# Cap worker processes at 24; sched_getaffinity reports the CPUs this
# process is actually allowed to run on (respects cgroup/affinity limits).
n_cores = min(24, len(os.sched_getaffinity(0)))
print(f"Using {n_cores} worker cores")
83 | 98 | embedder = Embedder( |
|
91 | 106 | pred_embedding=["cell_type_ontology_term_id"], |
92 | 107 | keep_all_cls_pred=False, |
93 | 108 | output_expression="none", |
| 109 | + save_every=30_000, |
94 | 110 | precision=precision, |
95 | 111 | dtype=dtype, |
96 | 112 | ) |
|
101 | 117 | obs=input.obs[[]], |
102 | 118 | var=input.var[[]], |
103 | 119 | obsm={ |
104 | | - "X_emb": embedded.obsm["scprint"], |
| 120 | + "X_emb": embedded.obsm["scprint_emb"], |
105 | 121 | }, |
106 | 122 | uns={ |
107 | 123 | "dataset_id": input.uns["dataset_id"], |
|
0 commit comments