
Commit a75dc6e

jkobject and lazappi authored
Update scPRINT to handle large datasets (#54)
* dbug scprint
* allowing flash attn
* Update _viash.yaml
* Update CHANGELOG
* adding some debug
* better model loading and new model
* final debug
* better now
* finish debug
* ending tests successfully
* removing flag
* new dataloader version
* Update CHANGELOG

Co-authored-by: Luke Zappia <[email protected]>

1 parent: 81856f1

File tree: 3 files changed (+43, −24 lines)

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -3,6 +3,7 @@
 ## Minor changes
 
 * Un-pin the scPRINT version and update parameters (PR #51)
+* Update scPRINT to better handle large datasets, including a new default model (PR #54)
 
 # task_batch_integration 2.0.0
 
```

src/methods/scprint/config.vsh.yaml

Lines changed: 8 additions & 6 deletions
```diff
@@ -35,7 +35,7 @@ info:
   scprint_large:
     model_name: "large"
   scprint_medium:
-    model_name: "medium"
+    model_name: "v2-medium"
   scprint_small:
     model_name: "small"
 test_setup:
```
```diff
@@ -48,8 +48,8 @@ arguments:
   - name: "--model_name"
     type: "string"
     description: Which model to use. Not used if --model is provided.
-    choices: ["large", "medium", "small"]
-    default: "large"
+    choices: ["large", "v2-medium", "small"]
+    default: "v2-medium"
   - name: --model
     type: file
     description: Path to the scPRINT model.
```
```diff
@@ -75,15 +75,17 @@ engines:
     setup:
       - type: python
         pip:
-          - scprint
+          - git+https://github.com/cantinilab/scPRINT.git@d8cc270b099c8d5dacf6913acc26f2b696685b2b
+          - gseapy==1.1.2
+          - git+https://github.com/jkobject/scDataLoader.git@c67c24a2e5c62399912be39169aae76e29e108aa
       - type: docker
         run: lamin init --storage ./main --name main --schema bionty
       - type: docker
         run: lamin load anonymous/main
-      - type: python
-        script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()
       - type: python
         script: import bionty as bt; bt.core.sync_all_sources_to_latest()
+      - type: python
+        script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()
 runners:
   - type: executable
   - type: nextflow
```
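The pip dependencies are now pinned to specific commits of scPRINT and scDataLoader (plus `gseapy==1.1.2`), and the ontology population step is moved to run after the bionty source sync. For reference, a minimal sketch of what the two Python setup steps execute, in their new order; this assumes the pinned packages above are installed and the `lamin` CLI steps have already created and loaded the `anonymous/main` instance:

```python
# Sketch of the config's Python setup steps, in their reordered sequence.
# Assumes `lamin init --storage ./main --name main --schema bionty` and
# `lamin load anonymous/main` have already run in the Docker build.
import bionty as bt
from scdataloader.utils import populate_my_ontology

# Sync all bionty ontology sources to their latest versions first...
bt.core.sync_all_sources_to_latest()

# ...then populate the local ontology tables used by scDataLoader.
populate_my_ontology()
```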

src/methods/scprint/script.py

Lines changed: 34 additions & 18 deletions
```diff
@@ -13,7 +13,7 @@
 par = {
     "input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
     "output": "output.h5ad",
-    "model_name": "large",
+    "model_name": "v2-medium",
     "model": None,
 }
 meta = {"name": "scprint"}
```
```diff
@@ -30,14 +30,18 @@
 
 print("\n>>> Reading input data...", flush=True)
 input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
-if input.uns["dataset_organism"] == "homo_sapiens":
-    input.obs["organism_ontology_term_id"] = "NCBITaxon:9606"
-elif input.uns["dataset_organism"] == "mus_musculus":
-    input.obs["organism_ontology_term_id"] = "NCBITaxon:10090"
-else:
-    exit_non_applicable(
-        f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
-    )
+if (
+    "organism_ontology_term_id" not in input.obs.columns
+    and "dataset_organism" in input.uns
+):
+    if input.uns["dataset_organism"] == "homo_sapiens":
+        input.obs["organism_ontology_term_id"] = "NCBITaxon:9606"
+    elif input.uns["dataset_organism"] == "mus_musculus":
+        input.obs["organism_ontology_term_id"] = "NCBITaxon:10090"
+    else:
+        exit_non_applicable(
+            f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
+        )
 adata = input.copy()
 
 print("\n>>> Preprocessing data...", flush=True)
```
```diff
@@ -59,25 +63,36 @@
         repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
     )
 
-print("\n>>> Embedding data...", flush=True)
 if torch.cuda.is_available():
     print("CUDA is available, using GPU", flush=True)
     precision = "16"
     dtype = torch.float16
-    transformer="flash"
+    transformer = "flash"
 else:
     print("CUDA is not available, using CPU", flush=True)
     precision = "32"
     dtype = torch.float32
-    transformer="normal"
+    transformer = "normal"
 
 print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
-model = scPrint.load_from_checkpoint(
-    model_checkpoint_file,
-    transformer=transformer,  # Don't use this for GPUs with flashattention
-    precpt_gene_emb=None,
-)
 
+m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
+if "label_counts" in m["hyper_parameters"]:
+    model = scPrint.load_from_checkpoint(
+        model_checkpoint_file,
+        transformer=transformer,  # Don't use this for GPUs with flashattention
+        precpt_gene_emb=None,
+        classes=m["hyper_parameters"]["label_counts"],
+    )
+else:
+    model = scPrint.load_from_checkpoint(
+        model_checkpoint_file,
+        transformer=transformer,  # Don't use this for GPUs with flashattention
+        precpt_gene_emb=None,
+    )
+del m
+
+print("\n>>> Embedding data...", flush=True)
 n_cores = min(len(os.sched_getaffinity(0)), 24)
 print(f"Using {n_cores} worker cores")
 embedder = Embedder(
```
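Loading the checkpoint once on CPU to inspect its `hyper_parameters` lets a single code path handle both the older checkpoints and the new v2 ones, which store per-label class counts that must be passed back in as `classes`. Since the two `load_from_checkpoint` branches differ only in that keyword, the pattern can be condensed; a sketch equivalent to the code above, under the same assumptions about the checkpoint layout:

```python
import torch
from scprint import scPrint

# Peek at the checkpoint's stored hyperparameters without building the
# model; map_location="cpu" keeps the raw dict off the GPU.
ckpt = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))

extra_kwargs = {}
if "label_counts" in ckpt["hyper_parameters"]:
    # Newer (v2) checkpoints record class label counts that the model
    # expects at construction time.
    extra_kwargs["classes"] = ckpt["hyper_parameters"]["label_counts"]
del ckpt  # drop the raw checkpoint dict before the real load

model = scPrint.load_from_checkpoint(
    model_checkpoint_file,
    transformer=transformer,  # "normal" on CPU; "flash" on CUDA GPUs
    precpt_gene_emb=None,
    **extra_kwargs,
)
```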
```diff
@@ -91,6 +106,7 @@
     pred_embedding=["cell_type_ontology_term_id"],
     keep_all_cls_pred=False,
     output_expression="none",
+    save_every=30_000,
     precision=precision,
     dtype=dtype,
 )
```
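The new `save_every=30_000` argument is the heart of the large-dataset fix in this hunk: rather than accumulating embeddings for every cell at once, the embedder can process and persist results in bounded chunks of 30,000 cells. scPRINT's internals are not part of this diff, but the general shape of such chunking looks something like the following generic sketch (not scPRINT code; `embed_fn` is a hypothetical callable):

```python
import numpy as np

def embed_in_chunks(embed_fn, n_cells, save_every=30_000):
    """Embed cells in fixed-size chunks so peak memory stays bounded.

    `embed_fn(start, stop)` is a hypothetical callable returning the
    embedding matrix for cells [start, stop). Chunks are concatenated
    here, but they could equally be flushed to disk as they complete.
    """
    chunks = []
    for start in range(0, n_cells, save_every):
        stop = min(start + save_every, n_cells)
        chunks.append(embed_fn(start, stop))
    return np.concatenate(chunks, axis=0)
```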
```diff
@@ -101,7 +117,7 @@
     obs=input.obs[[]],
     var=input.var[[]],
     obsm={
-        "X_emb": embedded.obsm["scprint"],
+        "X_emb": embedded.obsm["scprint_emb"],
     },
     uns={
         "dataset_id": input.uns["dataset_id"],
```
