src/methods/scgpt_finetuned/config.vsh.yaml (13 additions & 6 deletions)
@@ -51,13 +51,20 @@ engines:
     image: openproblems/base_pytorch_nvidia:1
     # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
     setup:
-      - type: python
-        pypi:
-          - gdown
-          - scgpt # Install from PyPI to get dependencies
+      #- type: python
+      #  pypi:
+      #    - gdown
+      #    - scgpt # Install from PyPI to get dependencies
+      #- type: docker
+      #  # Force re-installing from GitHub to get bug fixes
+      #  run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
       - type: docker
-        # Force re-installing from GitHub to get bug fixes
-        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
+        run: |
+          git clone https://github.com/bowang-lab/scGPT && \
+          pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \
+          pip install "flash-attn<1.0.5" --no-build-isolation && \
+          pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \
+          cd scGPT && pip install -e . --no-deps

 runners:
   - type: executable
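Both scGPT variants get the same environment rework: the PyPI install is commented out in favour of a docker build step that clones the repo, pins the torch/cu121 stack, and installs scGPT with --no-deps so the explicit pins win. A minimal sanity check along these lines (illustrative only, not part of the PR) can confirm the rebuilt image resolves cleanly:

import importlib.util

import torch

# The config pins the cu121 wheel of torch 2.2.0.
assert torch.__version__.startswith("2.2.0"), torch.__version__

# scGPT is installed from the cloned repo with --no-deps, so a clean import
# is the real test that every transitive dependency was pinned correctly.
import scgpt  # noqa: F401

# flash-attn may still be absent (see the TODO above); scGPT is expected to
# fall back to standard attention in that case, so only report its presence.
if importlib.util.find_spec("flash_attn") is None:
    print("flash-attn not importable; standard attention will be used")
else:
    print("flash-attn available")

print("CUDA available:", torch.cuda.is_available())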
src/methods/scgpt_zeroshot/config.vsh.yaml (13 additions & 6 deletions)
@@ -53,13 +53,20 @@ engines:
     image: openproblems/base_pytorch_nvidia:1
     # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
     setup:
-      - type: python
-        pypi:
-          - gdown
-          - scgpt # Install from PyPI to get dependencies
+      #- type: python
+      #  pypi:
+      #    - gdown
+      #    - scgpt # Install from PyPI to get dependencies
+      #- type: docker
+      #  # Force re-installing from GitHub to get bug fixes
+      #  run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
       - type: docker
-        # Force re-installing from GitHub to get bug fixes
-        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
+        run: |
+          git clone https://github.com/bowang-lab/scGPT && \
+          pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \
+          pip install "flash-attn<1.0.5" --no-build-isolation && \
+          pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \
+          cd scGPT && pip install -e . --no-deps

 runners:
   - type: executable
src/methods/scprint/config.vsh.yaml (2 additions & 2 deletions)
@@ -75,8 +75,7 @@ engines:
     setup:
       - type: python
         pip:
-          - scprint>=2.3.0
-          - gseapy>=1.1.8
+          - scprint==2.3.5
       - type: docker
         run: |
           lamin init --storage ./main --name main --schema bionty && \
@@ -87,6 +86,7 @@ engines:
         script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()
 runners:
   - type: executable
+    # docker_run_args: --gpus all
   - type: nextflow
     directives:
       label: [hightime, highmem, midcpu, gpu, highsharedmem]
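Pinning scprint to an exact version (==2.3.5 instead of the earlier >=2.3.0 range, dropping the separate gseapy pin) makes the image reproducible. A small runtime guard in this spirit (illustrative only, not part of the PR) would catch a drifted build before any heavy work starts:

from importlib.metadata import version

# Illustrative guard (not part of this PR): fail fast if the image drifts
# from the exact pin in the config.
assert version("scprint") == "2.3.5", version("scprint")

# populate_my_ontology() already ran at image build time (see the docker
# setup step above); importing it here only confirms scdataloader is usable.
from scdataloader.utils import populate_my_ontology  # noqa: F401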
src/methods/scprint/script.py (32 additions & 11 deletions)
@@ -6,6 +6,7 @@
 import torch
 from huggingface_hub import hf_hub_download
 from scdataloader import Preprocessor
+from scdataloader.utils import load_genes
 from scprint import scPrint
 from scprint.tasks import Embedder

@@ -63,34 +64,55 @@
     repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
 )

+print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
+
 if torch.cuda.is_available():
     print("CUDA is available, using GPU", flush=True)
     precision = "16"
     dtype = torch.float16
     transformer = "flash"
 else:
     print("CUDA is not available, using CPU", flush=True)
     precision = "32"
     dtype = torch.float32
     transformer = "normal"

-print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
-try:
-    m = torch.load(model_checkpoint_file)
-# if not use this instead since the model weights are by default mapped to GPU types
-except RuntimeError:
-    m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))

+m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
+# Both steps below work around compatibility issues between different versions of the pretrained model, so it can be loaded with the correct transformer.
+if "prenorm" in m["hyper_parameters"]:
+    m["hyper_parameters"].pop("prenorm")
+    torch.save(m, model_checkpoint_file)
 if "label_counts" in m["hyper_parameters"]:
+    # precpt_gene_emb must be None, otherwise the model looks for its precomputed gene embedding files even though they were already converted into model weights and are not needed for a pretrained model.
     model = scPrint.load_from_checkpoint(
         model_checkpoint_file,
-        transformer=transformer,  # Don't use this for GPUs with flashattention
         precpt_gene_emb=None,
         classes=m["hyper_parameters"]["label_counts"],
+        transformer=transformer,
     )
 else:
     model = scPrint.load_from_checkpoint(
-        model_checkpoint_file,
-        transformer=transformer, # Don't use this for GPUs with flashattention
-        precpt_gene_emb=None,
+        model_checkpoint_file, precpt_gene_emb=None, transformer=transformer
     )
+del m
+# This can happen when the model was trained on a different gene set than the current ontology (e.g. newer ontologies). Genes in the ontology but not in the model are fine; the opposite is not, so remove genes that are in the model but not in the ontology.
+missing = set(model.genes) - set(load_genes(model.organisms).index)
+if len(missing) > 0:
+    print(
+        "Warning: some genes mismatch between model and ontology: solving...",
+    )
+    model._rm_genes(missing)
+
+# Again, without a GPU the model must be converted to float32.
+if not torch.cuda.is_available():
+    model = model.to(torch.float32)
+
+# Inference can run in float16 on a GPU; otherwise use float32.
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+# Models often load with some parts still mapped to "cuda" and others to "cpu", so make sure the model ends up fully on the right device.
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")
+
+print("\n>>> Embedding data...", flush=True)
+n_cores = min(len(os.sched_getaffinity(0)), 24)
@@ -107,7 +129,6 @@
     keep_all_cls_pred=False,
     output_expression="none",
     save_every=30_000,
-    precision=precision,
     dtype=dtype,
 )
 embedded, _ = embedder(model, adata, cache=False)
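The core of the script.py change is a sanitize-then-reload step: a Lightning checkpoint stores hyperparameters next to the weights, so a key written by an older scPrint (here prenorm) can make load_from_checkpoint fail against a newer class, and weights saved on GPU need map_location to open on a CPU-only machine. A distilled sketch of that pattern, with generic names standing in for the scPrint specifics:

import torch


def sanitize_checkpoint(path, stale_keys=("prenorm",)):
    # Map to CPU unconditionally: checkpoints saved on GPU otherwise
    # require CUDA just to be opened.
    ckpt = torch.load(path, map_location=torch.device("cpu"))
    dirty = False
    for key in stale_keys:
        if key in ckpt.get("hyper_parameters", {}):
            ckpt["hyper_parameters"].pop(key)
            dirty = True
    if dirty:
        # Persist the cleaned file so later loads skip this step.
        torch.save(ckpt, path)
    return ckpt


# Usage mirroring the diff: clean the file first, then hand the path to
# scPrint.load_from_checkpoint(...) as script.py does.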