diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 6b2321f..2bd56c1 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -36,14 +36,14 @@ info: preferred_normalization: counts variants: scprint_large: - model_name: "large" + model_name: "large-v1" scprint_medium: - model_name: "v2-medium" + model_name: "medium-v1.5" scprint_small: - model_name: "small" + model_name: "small-v1" test_setup: run: - model_name: small + model_name: small-v1 batch_size: 16 max_len: 100 @@ -51,8 +51,8 @@ arguments: - name: "--model_name" type: "string" description: Which model to use. Not used if --model is provided. - choices: ["large", "v2-medium", "small"] - default: "v2-medium" + choices: ["large-v1", "medium-v1.5", "small-v1"] + default: "medium-v1.5" - name: --model type: file description: Path to the scPRINT model. @@ -64,14 +64,14 @@ arguments: - name: --max_len type: integer description: The maximum length of the gene sequence. - default: 4000 + default: 2000 - name: --infer_matches type: string description: The method to use to infer the matches between the predicted and true labels. choices: ["direct", "linear_sum_assignment"] - default: "linear_sum_assignment" + default: "direct" resources: - type: python_script @@ -84,8 +84,7 @@ engines: setup: - type: python pip: - - scprint>=2.3.0 - - gseapy>=1.1.8 + - scprint==2.3.5 - type: docker run: | lamin init --storage ./main --name main --schema bionty && \ @@ -97,6 +96,7 @@ engines: runners: - type: executable + # docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 7881a90..2994d8e 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -7,6 +7,7 @@ import torch from huggingface_hub import hf_hub_download from scdataloader import Preprocessor +from scdataloader.utils import load_genes from scipy.optimize import linear_sum_assignment from scipy.spatial import distance from scprint import scPrint @@ -16,9 +17,9 @@ "input_train": "resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad", "input_test": "resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad", "output": "output.h5ad", - "model_name": "v2-medium", + "model_name": "medium-v1.5", "model": None, - "infer_matches": "linear_sum_assignment", + "infer_matches": "direct", } meta = {"name": "scprint"} ## VIASH END @@ -34,7 +35,9 @@ print("\n>>> Reading input data...", flush=True) input_train = ad.read_h5ad(par["input_train"]) input_test = ad.read_h5ad(par["input_test"]) + input_test_uns = input_test.uns.copy() +input_test_obs = input_test.obs.copy() print("\n>>> Preprocessing input data...", flush=True) # store organism ontology term id @@ -85,32 +88,52 @@ if torch.cuda.is_available(): print("CUDA is available, using GPU", flush=True) - precision = "16" - dtype = torch.float16 transformer = "flash" else: print("CUDA is not available, using CPU", flush=True) - precision = "32" - dtype = torch.float32 transformer = "normal" -print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) - -m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) +# make sure that you check if you have a GPU with flashattention or not (see README) +try: + m = torch.load(model_checkpoint_file) +# if not use this instead since the model weights are by default mapped to GPU types +except RuntimeError: + m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) + +# both are for compatibility issues with different versions of the pretrained model, so we need to load it with the correct transformer +if "prenorm" in m["hyper_parameters"]: + m["hyper_parameters"].pop("prenorm") + torch.save(m, model_checkpoint_file) if "label_counts" in m["hyper_parameters"]: + # you need to set precpt_gene_emb=None otherwise the model will look for its precomputed gene embeddings files although they were already converted into model weights, so you don't need this file for a pretrained model model = scPrint.load_from_checkpoint( model_checkpoint_file, - transformer=transformer, # Don't use this for GPUs with flashattention precpt_gene_emb=None, classes=m["hyper_parameters"]["label_counts"], + transformer=transformer, ) else: model = scPrint.load_from_checkpoint( - model_checkpoint_file, - transformer=transformer, # Don't use this for GPUs with flashattention - precpt_gene_emb=None, + model_checkpoint_file, precpt_gene_emb=None, transformer=transformer ) del m +# this might happen if you have a model that was trained with a different set of genes than the one you are using in the ontology (e.g. newer ontologies), While having genes in the onlogy not in the model is fine. the opposite is not, so we need to remove the genes that are in the model but not in the ontology +missing = set(model.genes) - set(load_genes(model.organisms).index) +if len(missing) > 0: + print( + "Warning: some genes missmatch exist between model and ontology: solving...", + ) + model._rm_genes(missing) + +# again if not on GPU you need to convert the model to float64 +if not torch.cuda.is_available(): + model = model.to(torch.float32) + +# you can perform your inference on float16 if you have a GPU, otherwise use float64 +dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + +# the models are often loaded with some parts still displayed as "cuda" and some as "cpu", so we need to make sure that the model is fully on the right device +model = model.to("cuda" if torch.cuda.is_available() else "cpu") n_cores = max(16, len(os.sched_getaffinity(0))) @@ -123,7 +146,6 @@ num_workers=n_cores, doclass=True, doplot=False, - precision=precision, dtype=dtype, pred_embedding=["cell_type_ontology_term_id"], keep_all_cls_pred=False, @@ -217,7 +239,6 @@ num_workers=n_cores, doclass=True, doplot=False, - precision=precision, dtype=dtype, pred_embedding=["cell_type_ontology_term_id"], keep_all_cls_pred=False, @@ -230,6 +251,7 @@ "conv_pred_cell_type_ontology_term_id" ].values input_test.obs = input_test.obs.replace(dict(label_pred=matches)) +input_test.obs.index = input_test_obs.index print("\n>>> Storing output...", flush=True) output = ad.AnnData(