20 changes: 10 additions & 10 deletions src/methods/scprint/config.vsh.yaml
@@ -36,23 +36,23 @@ info:
   preferred_normalization: counts
   variants:
     scprint_large:
-      model_name: "large"
+      model_name: "large-v1"
     scprint_medium:
-      model_name: "v2-medium"
+      model_name: "medium-v1.5"
     scprint_small:
-      model_name: "small"
+      model_name: "small-v1"
   test_setup:
     run:
-      model_name: small
+      model_name: small-v1
       batch_size: 16
       max_len: 100
 
 arguments:
   - name: "--model_name"
     type: "string"
     description: Which model to use. Not used if --model is provided.
-    choices: ["large", "v2-medium", "small"]
-    default: "v2-medium"
+    choices: ["large-v1", "medium-v1.5", "small-v1"]
+    default: "medium-v1.5"
   - name: --model
     type: file
     description: Path to the scPRINT model.
@@ -64,14 +64,14 @@ arguments:
   - name: --max_len
     type: integer
     description: The maximum length of the gene sequence.
-    default: 4000
+    default: 2000
   - name: --infer_matches
     type: string
    description:
       The method to use to infer the matches between the predicted and
       true labels.
     choices: ["direct", "linear_sum_assignment"]
-    default: "linear_sum_assignment"
+    default: "direct"
 
 resources:
   - type: python_script
@@ -84,8 +84,7 @@ engines:
     setup:
       - type: python
         pip:
-          - scprint>=2.3.0
-          - gseapy>=1.1.8
+          - scprint==2.3.5
       - type: docker
         run: |
           lamin init --storage ./main --name main --schema bionty && \
@@ -97,6 +96,7 @@
 
 runners:
   - type: executable
+    # docker_run_args: --gpus all
   - type: nextflow
     directives:
       label: [hightime, highmem, midcpu, gpu, highsharedmem]
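
The renamed variants only work if each `model_name` resolves to a fetchable checkpoint. A minimal sketch of that resolution, assuming (as the script's `hf_hub_download` import suggests) that the checkpoints live on the Hugging Face Hub; the repo id and the `<model_name>.ckpt` filename pattern are assumptions, not confirmed by this diff:

```python
# Sketch only: repo_id and the filename pattern are assumptions.
from huggingface_hub import hf_hub_download

model_checkpoint_file = hf_hub_download(
    repo_id="jkobject/scPRINT",   # assumed location of the pretrained checkpoints
    filename="medium-v1.5.ckpt",  # assumed "<model_name>.ckpt" naming for the new default
)
print(model_checkpoint_file)
```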
52 changes: 37 additions & 15 deletions src/methods/scprint/script.py
@@ -7,6 +7,7 @@
 import torch
 from huggingface_hub import hf_hub_download
 from scdataloader import Preprocessor
+from scdataloader.utils import load_genes
 from scipy.optimize import linear_sum_assignment
 from scipy.spatial import distance
 from scprint import scPrint
@@ -16,9 +17,9 @@
     "input_train": "resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad",
     "input_test": "resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad",
     "output": "output.h5ad",
-    "model_name": "v2-medium",
+    "model_name": "medium-v1.5",
     "model": None,
-    "infer_matches": "linear_sum_assignment",
+    "infer_matches": "direct",
 }
 meta = {"name": "scprint"}
 ## VIASH END
@@ -34,7 +35,9 @@
 print("\n>>> Reading input data...", flush=True)
 input_train = ad.read_h5ad(par["input_train"])
 input_test = ad.read_h5ad(par["input_test"])
 
+input_test_uns = input_test.uns.copy()
+input_test_obs = input_test.obs.copy()
 
 print("\n>>> Preprocessing input data...", flush=True)
 # store organism ontology term id
@@ -85,32 +88,52 @@
 
 if torch.cuda.is_available():
     print("CUDA is available, using GPU", flush=True)
-    precision = "16"
+    dtype = torch.float16
     transformer = "flash"
 else:
     print("CUDA is not available, using CPU", flush=True)
-    precision = "32"
+    dtype = torch.float32
     transformer = "normal"
 
 print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
 
-m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
+# make sure to check whether your GPU supports flash attention (see README)
+try:
+    m = torch.load(model_checkpoint_file)
+# if it does not, fall back to a CPU mapping, since the model weights are mapped to GPU types by default
+except RuntimeError:
+    m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
 
+# both checks below handle compatibility issues with different versions of the pretrained model, so the checkpoint is loaded with the right hyperparameters
 if "prenorm" in m["hyper_parameters"]:
     m["hyper_parameters"].pop("prenorm")
     torch.save(m, model_checkpoint_file)
 if "label_counts" in m["hyper_parameters"]:
+    # precpt_gene_emb=None is required, otherwise the model looks for its precomputed gene embedding files even though they were already converted into model weights; a pretrained model no longer needs this file
     model = scPrint.load_from_checkpoint(
         model_checkpoint_file,
-        transformer=transformer,  # Don't use this for GPUs with flashattention
         precpt_gene_emb=None,
         classes=m["hyper_parameters"]["label_counts"],
+        transformer=transformer,
     )
 else:
     model = scPrint.load_from_checkpoint(
-        model_checkpoint_file,
-        transformer=transformer,  # Don't use this for GPUs with flashattention
-        precpt_gene_emb=None,
+        model_checkpoint_file, precpt_gene_emb=None, transformer=transformer
     )
+del m
+# a mismatch can happen when the model was trained on a different gene set than the one in the current ontology (e.g. newer ontologies); genes in the ontology that are missing from the model are fine, but the opposite is not, so remove the genes that are in the model but not in the ontology
+missing = set(model.genes) - set(load_genes(model.organisms).index)
+if len(missing) > 0:
+    print(
+        "Warning: gene mismatch between model and ontology, removing missing genes...",
+    )
+    model._rm_genes(missing)
 
+# if not on a GPU, convert the model to float32
+if not torch.cuda.is_available():
+    model = model.to(torch.float32)
 
+# inference can run in float16 on a GPU; otherwise use float32
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
+# models are often loaded with some parts still on "cuda" and some on "cpu", so make sure the model ends up fully on the right device
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")
 
 n_cores = max(16, len(os.sched_getaffinity(0)))
 
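
The comment above the `try`/`except` defers the flash-attention check to the README, while the code itself only tests `torch.cuda.is_available()`. A minimal sketch of a stricter guard, assuming FlashAttention needs an Ampere-or-newer GPU (compute capability 8.0 or higher); the threshold is an assumption, so check the scPRINT README for the actual requirement:

```python
# Sketch only: the compute-capability threshold is an assumption.
import torch

def supports_flash_attention() -> bool:
    """Return True only if a CUDA GPU new enough for FlashAttention is present."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 8  # Ampere (8.x) or newer

transformer = "flash" if supports_flash_attention() else "normal"
```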
@@ -123,7 +146,6 @@
     num_workers=n_cores,
     doclass=True,
     doplot=False,
-    precision=precision,
+    dtype=dtype,
     pred_embedding=["cell_type_ontology_term_id"],
     keep_all_cls_pred=False,
@@ -217,7 +239,6 @@
     num_workers=n_cores,
     doclass=True,
     doplot=False,
-    precision=precision,
+    dtype=dtype,
     pred_embedding=["cell_type_ontology_term_id"],
     keep_all_cls_pred=False,
@@ -230,6 +251,7 @@
     "conv_pred_cell_type_ontology_term_id"
 ].values
 input_test.obs = input_test.obs.replace(dict(label_pred=matches))
+input_test.obs.index = input_test_obs.index
 
 print("\n>>> Storing output...", flush=True)
 output = ad.AnnData(
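
For reference, the two `--infer_matches` strategies whose default flips in this PR: "direct" keeps the model's converted predictions as-is, while "linear_sum_assignment" finds the one-to-one mapping between predicted and true label sets that maximizes their co-occurrence. A minimal sketch with a hypothetical helper; the script's actual matching code sits outside this diff:

```python
# Sketch only: infer_matches() is a hypothetical helper illustrating the two strategies.
import numpy as np
from scipy.optimize import linear_sum_assignment

def infer_matches(pred: np.ndarray, true: np.ndarray, method: str = "direct") -> dict:
    pred_labels, true_labels = np.unique(pred), np.unique(true)
    if method == "direct":
        # keep the predicted labels unchanged
        return {p: p for p in pred_labels}
    # contingency matrix: how often each predicted label co-occurs with each true label
    counts = np.zeros((len(pred_labels), len(true_labels)))
    for i, p in enumerate(pred_labels):
        for j, t in enumerate(true_labels):
            counts[i, j] = np.sum((pred == p) & (true == t))
    # maximize total agreement (negate, since linear_sum_assignment minimizes cost)
    rows, cols = linear_sum_assignment(-counts)
    return {pred_labels[i]: true_labels[j] for i, j in zip(rows, cols)}
```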