diff --git a/src/methods/scgpt_finetuned/config.vsh.yaml b/src/methods/scgpt_finetuned/config.vsh.yaml
index 20760aa3..2b949cbb 100644
--- a/src/methods/scgpt_finetuned/config.vsh.yaml
+++ b/src/methods/scgpt_finetuned/config.vsh.yaml
@@ -51,13 +51,20 @@ engines:
     image: openproblems/base_pytorch_nvidia:1
     # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
     setup:
-      - type: python
-        pypi:
-          - gdown
-          - scgpt # Install from PyPI to get dependencies
+      #- type: python
+      #  pypi:
+      #    - gdown
+      #    - scgpt # Install from PyPI to get dependencies
+      #- type: docker
+      #  # Force re-installing from GitHub to get bug fixes
+      #  run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
       - type: docker
-        # Force re-installing from GitHub to get bug fixes
-        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
+        run: |
+          git clone https://github.com/bowang-lab/scGPT && \
+          pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \
+          pip install "flash-attn<1.0.5" --no-build-isolation && \
+          pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \
+          cd scGPT && pip install -e . --no-deps

 runners:
   - type: executable
diff --git a/src/methods/scgpt_zeroshot/config.vsh.yaml b/src/methods/scgpt_zeroshot/config.vsh.yaml
index ba2455c6..3ff6425c 100644
--- a/src/methods/scgpt_zeroshot/config.vsh.yaml
+++ b/src/methods/scgpt_zeroshot/config.vsh.yaml
@@ -53,13 +53,20 @@ engines:
     image: openproblems/base_pytorch_nvidia:1
     # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
     setup:
-      - type: python
-        pypi:
-          - gdown
-          - scgpt # Install from PyPI to get dependencies
+      #- type: python
+      #  pypi:
+      #    - gdown
+      #    - scgpt # Install from PyPI to get dependencies
+      #- type: docker
+      #  # Force re-installing from GitHub to get bug fixes
+      #  run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
       - type: docker
-        # Force re-installing from GitHub to get bug fixes
-        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
+        run: |
+          git clone https://github.com/bowang-lab/scGPT && \
+          pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \
+          pip install "flash-attn<1.0.5" --no-build-isolation && \
+          pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \
+          cd scGPT && pip install -e . --no-deps

 runners:
   - type: executable
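Both scGPT configs above trade the PyPI package for a pinned, from-source install: torch 2.2.0 against the cu121 wheel index, flash-attn below 1.0.5 with build isolation disabled, explicitly pinned dependencies, and scGPT installed from a clone with `--no-deps` so pip cannot pull in conflicting versions. A minimal, hypothetical sanity check for the resulting image could look like the sketch below; none of it is part of the diff, it only uses stock torch and importlib APIs, and it assumes the pins above resolved:

```python
# Hypothetical post-build sanity check for the pinned scGPT environment;
# assumes the image was built with the setup block above (not part of the diff).
import importlib.util
from importlib.metadata import version

import torch

# the setup block pins torch 2.2.0 built against CUDA 12.1 wheels
print("torch:", version("torch"), "| CUDA build:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())

# flash-attn may legitimately be absent (see the TODO above); scGPT is
# expected to fall back to standard attention when the import fails
if importlib.util.find_spec("flash_attn") is None:
    print("flash-attn not importable; expecting fallback to normal attention")

# scGPT itself was installed with --no-deps, so an import error here usually
# points at one of the explicitly pinned dependencies rather than scGPT
import scgpt  # noqa: E402

print("scgpt imported from:", scgpt.__file__)
```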
diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml
index 83e3f9a8..0e3020e5 100644
--- a/src/methods/scprint/config.vsh.yaml
+++ b/src/methods/scprint/config.vsh.yaml
@@ -75,8 +75,7 @@ engines:
     setup:
       - type: python
         pip:
-          - scprint>=2.3.0
-          - gseapy>=1.1.8
+          - scprint==2.3.5
       - type: docker
         run: |
           lamin init --storage ./main --name main --schema bionty && \
@@ -87,6 +86,7 @@ engines:
         script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()
 runners:
   - type: executable
+    # docker_run_args: --gpus all
   - type: nextflow
     directives:
       label: [hightime, highmem, midcpu, gpu, highsharedmem]
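Besides pinning `scprint==2.3.5` (and dropping the `gseapy` pin), the executable runner gains a commented-out `docker_run_args: --gpus all` hint. Whether the container actually sees a GPU decides which code path script.py takes below; a quick check, assuming only stock torch APIs, might be:

```python
# Hedged check of GPU visibility inside the container, relevant when toggling
# the commented-out `docker_run_args: --gpus all` above.
import torch

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"cuda:{i} ->", torch.cuda.get_device_name(i))
else:
    # without a GPU runtime flag the executable runner sees no devices and
    # script.py (below) falls back to its CPU code path
    print("no CUDA devices visible; scPrint will run on CPU")
```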
diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py
index 2342875a..46297afa 100644
--- a/src/methods/scprint/script.py
+++ b/src/methods/scprint/script.py
@@ -6,6 +6,7 @@
 import torch
 from huggingface_hub import hf_hub_download
 from scdataloader import Preprocessor
+from scdataloader.utils import load_genes
 from scprint import scPrint
 from scprint.tasks import Embedder

@@ -63,34 +64,55 @@
     repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
 )

+print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
+
 if torch.cuda.is_available():
     print("CUDA is available, using GPU", flush=True)
-    precision = "16"
-    dtype = torch.float16
     transformer = "flash"
 else:
     print("CUDA is not available, using CPU", flush=True)
-    precision = "32"
-    dtype = torch.float32
     transformer = "normal"

-print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
+try:
+    m = torch.load(model_checkpoint_file)
+# if that fails, fall back to CPU mapping: the model weights are mapped to GPU types by default
+except RuntimeError:
+    m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))

-m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
+# both branches below handle compatibility with checkpoints saved by different versions of the pretrained model
+if "prenorm" in m["hyper_parameters"]:
+    m["hyper_parameters"].pop("prenorm")
+    torch.save(m, model_checkpoint_file)
 if "label_counts" in m["hyper_parameters"]:
+    # precpt_gene_emb=None stops the model from looking for precomputed gene embedding files; they were already converted into model weights, so a pretrained model does not need them
     model = scPrint.load_from_checkpoint(
         model_checkpoint_file,
-        transformer=transformer,  # Don't use this for GPUs with flashattention
         precpt_gene_emb=None,
         classes=m["hyper_parameters"]["label_counts"],
+        transformer=transformer,
     )
 else:
     model = scPrint.load_from_checkpoint(
-        model_checkpoint_file,
-        transformer=transformer,  # Don't use this for GPUs with flashattention
-        precpt_gene_emb=None,
+        model_checkpoint_file, precpt_gene_emb=None, transformer=transformer
     )
 del m
+# the model may have been trained on a different gene set than the current ontology (e.g. a newer ontology release); genes in the ontology but not in the model are fine, the opposite is not, so remove genes that are in the model but missing from the ontology
+missing = set(model.genes) - set(load_genes(model.organisms).index)
+if len(missing) > 0:
+    print(
+        "Warning: gene mismatch between model and ontology, removing missing genes...",
+    )
+    model._rm_genes(missing)
+
+# if not on GPU, convert the model to float32
+if not torch.cuda.is_available():
+    model = model.to(torch.float32)
+
+# inference can run in float16 on a GPU; on CPU use float32
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+# checkpoints often load with some parameters on "cuda" and others on "cpu", so make sure the whole model ends up on a single device
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")

 print("\n>>> Embedding data...", flush=True)
 n_cores = min(len(os.sched_getaffinity(0)), 24)
@@ -107,7 +129,6 @@
     keep_all_cls_pred=False,
     output_expression="none",
     save_every=30_000,
-    precision=precision,
     dtype=dtype,
 )
 embedded, _ = embedder(model, adata, cache=False)
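For reference, the loading pattern introduced in script.py can be distilled into a standalone sketch: try a plain `torch.load` first, remap to CPU when the checkpoint holds GPU tensors on a CPU-only machine, then pick dtype and device the same way the script does. The path `some_checkpoint.ckpt` and the helper name `load_checkpoint` are placeholders, not part of the diff:

```python
# Standalone sketch of the checkpoint-loading fallback used in script.py above.
import torch


def load_checkpoint(path: str) -> dict:
    try:
        # a plain load keeps tensors on the device they were saved from (GPU here)
        return torch.load(path)
    except RuntimeError:
        # CPU-only machine: remap GPU-saved tensors onto the CPU
        return torch.load(path, map_location=torch.device("cpu"))


ckpt = load_checkpoint("some_checkpoint.ckpt")  # placeholder path

# mirror the script's dtype/device choice: float16 on GPU, float32 on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(device, dtype, list(ckpt.get("hyper_parameters", {}))[:5])
```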