
Commit 1c34b59

jkobject and lazappi authored

Update scprint (#71)

* debug scprint
* allowing flash attn
* Update _viash.yaml
* Update CHANGELOG
* adding some debug
* better model loading and new model
* final debug
* better now
* finish debug
* ending tests successfully
* removing flag
* new dataloader version
* Update CHANGELOG
* solving some issues
* update scprint
* Update src/methods/scprint/script.py
  Co-authored-by: Luke Zappia <[email protected]>
* Update src/methods/scprint/script.py
* improve the scgpt installation (now uses flash attention)

---------

Co-authored-by: Luke Zappia <[email protected]>
Co-authored-by: Luke Zappia <[email protected]>
1 parent 37df0e7 commit 1c34b59

File tree

4 files changed (+60, -25 lines)


src/methods/scgpt_finetuned/config.vsh.yaml

Lines changed: 13 additions & 6 deletions
```diff
@@ -51,13 +51,20 @@ engines:
     image: openproblems/base_pytorch_nvidia:1
     # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
     setup:
-      - type: python
-        pypi:
-          - gdown
-          - scgpt # Install from PyPI to get dependencies
+      #- type: python
+      #  pypi:
+      #    - gdown
+      #    - scgpt # Install from PyPI to get dependencies
+      #- type: docker
+      #  # Force re-installing from GitHub to get bug fixes
+      #  run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
       - type: docker
-        # Force re-installing from GitHub to get bug fixes
-        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
+        run: |
+          git clone https://github.com/bowang-lab/scGPT && \
+          pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \
+          pip install "flash-attn<1.0.5" --no-build-isolation && \
+          pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \
+          cd scGPT && pip install -e . --no-deps

 runners:
   - type: executable
```
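Not part of the commit, but a quick way to confirm the pinned stack resolved consistently: a minimal sanity-check sketch, assuming it runs inside the image built from this config.

```python
# Hypothetical post-build sanity check (not part of the commit).
from importlib.metadata import version

import torch

print(torch.__version__)          # expect "2.2.0+cu121" from the pinned index URL
print(torch.cuda.is_available())  # True on a GPU runner, False on CPU-only builds

try:
    import flash_attn  # noqa: F401  # fails at import time if the wheel did not build
    print(version("flash-attn"))    # expect a version < 1.0.5
except ImportError:
    print("flash-attn not importable; scGPT should fall back to standard attention")
```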

src/methods/scgpt_zeroshot/config.vsh.yaml

Lines changed: 13 additions & 6 deletions
```diff
@@ -53,13 +53,20 @@ engines:
     image: openproblems/base_pytorch_nvidia:1
     # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
     setup:
-      - type: python
-        pypi:
-          - gdown
-          - scgpt # Install from PyPI to get dependencies
+      #- type: python
+      #  pypi:
+      #    - gdown
+      #    - scgpt # Install from PyPI to get dependencies
+      #- type: docker
+      #  # Force re-installing from GitHub to get bug fixes
+      #  run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
       - type: docker
-        # Force re-installing from GitHub to get bug fixes
-        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
+        run: |
+          git clone https://github.com/bowang-lab/scGPT && \
+          pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \
+          pip install "flash-attn<1.0.5" --no-build-isolation && \
+          pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \
+          cd scGPT && pip install -e . --no-deps

 runners:
   - type: executable
```

src/methods/scprint/config.vsh.yaml

Lines changed: 2 additions & 2 deletions
```diff
@@ -75,8 +75,7 @@ engines:
     setup:
       - type: python
         pip:
-          - scprint>=2.3.0
-          - gseapy>=1.1.8
+          - scprint==2.3.5
       - type: docker
         run: |
           lamin init --storage ./main --name main --schema bionty && \
@@ -87,6 +86,7 @@
       script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()

 runners:
   - type: executable
+    # docker_run_args: --gpus all
   - type: nextflow
     directives:
       label: [hightime, highmem, midcpu, gpu, highsharedmem]
```
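For context (not in the diff): the docker setup step first boots a local lamindb instance and then runs the one-line `populate_my_ontology()` script shown above. Expanded, and assuming scprint==2.3.5 pulls in scdataloader as a dependency, it amounts to:

```python
# Expanded form of the config's one-line setup script; assumes it runs once at
# image build time, after `lamin init --storage ./main --name main --schema bionty`.
from scdataloader.utils import populate_my_ontology

# Fills the local lamindb instance with the ontology records (cell types,
# tissues, diseases, ...) that scPRINT's dataloader expects, presumably so
# inference runs do not have to fetch them over the network later.
populate_my_ontology()
```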

src/methods/scprint/script.py

Lines changed: 32 additions & 11 deletions
```diff
@@ -6,6 +6,7 @@
 import torch
 from huggingface_hub import hf_hub_download
 from scdataloader import Preprocessor
+from scdataloader.utils import load_genes
 from scprint import scPrint
 from scprint.tasks import Embedder
```

```diff
@@ -63,34 +64,55 @@
     repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
 )

+print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
+
 if torch.cuda.is_available():
     print("CUDA is available, using GPU", flush=True)
-    precision = "16"
-    dtype = torch.float16
     transformer = "flash"
 else:
     print("CUDA is not available, using CPU", flush=True)
-    precision = "32"
-    dtype = torch.float32
     transformer = "normal"

-print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
+try:
+    m = torch.load(model_checkpoint_file)
+# if that fails, load with an explicit map_location, since the model weights
+# are mapped to GPU tensor types by default
+except RuntimeError:
+    m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))

-m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
+# both checks below handle compatibility with different versions of the
+# pretrained model, so it is loaded with the correct settings
+if "prenorm" in m["hyper_parameters"]:
+    m["hyper_parameters"].pop("prenorm")
+    torch.save(m, model_checkpoint_file)
 if "label_counts" in m["hyper_parameters"]:
+    # set precpt_gene_emb=None, otherwise the model looks for its precomputed
+    # gene embedding files; these were already converted into model weights,
+    # so a pretrained model does not need the file
     model = scPrint.load_from_checkpoint(
         model_checkpoint_file,
-        transformer=transformer, # Don't use this for GPUs with flashattention
         precpt_gene_emb=None,
         classes=m["hyper_parameters"]["label_counts"],
+        transformer=transformer,
     )
 else:
     model = scPrint.load_from_checkpoint(
-        model_checkpoint_file,
-        transformer=transformer, # Don't use this for GPUs with flashattention
-        precpt_gene_emb=None,
+        model_checkpoint_file, precpt_gene_emb=None, transformer=transformer
     )
 del m
+# this can happen when the model was trained with a different set of genes than
+# the ontology in use (e.g. newer ontologies); genes in the ontology but not in
+# the model are fine, the opposite is not, so remove the genes that are in the
+# model but not in the ontology
+missing = set(model.genes) - set(load_genes(model.organisms).index)
+if len(missing) > 0:
+    print(
+        "Warning: some gene mismatches exist between model and ontology: solving...",
+    )
+    model._rm_genes(missing)
+
+# again, if not on a GPU, the model needs to be converted to float32
+if not torch.cuda.is_available():
+    model = model.to(torch.float32)
+
+# inference can run in float16 on a GPU; otherwise use float32
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+# models are often loaded with some parts still on "cuda" and some on "cpu",
+# so make sure the model is fully on the right device
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")

 print("\n>>> Embedding data...", flush=True)
 n_cores = min(len(os.sched_getaffinity(0)), 24)
```
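The load-then-fallback and hyperparameter-stripping logic above generalizes to a small standalone pattern. A sketch, not the script's exact code; the checkpoint filename is a placeholder:

```python
# Sketch of the checkpoint-compatibility pattern used in the hunk above.
import torch

def load_ckpt(path: str) -> dict:
    """Load a Lightning-style checkpoint, remapping CUDA-typed weights to CPU if needed."""
    try:
        return torch.load(path)
    except RuntimeError:
        # The checkpoint stores CUDA tensors but no GPU is visible: remap to CPU.
        return torch.load(path, map_location=torch.device("cpu"))

ckpt = load_ckpt("scprint.ckpt")  # hypothetical filename
# Strip hyperparameters that newer scPrint releases no longer accept, then
# persist the cleaned checkpoint so later loads skip this step.
if "prenorm" in ckpt["hyper_parameters"]:
    ckpt["hyper_parameters"].pop("prenorm")
    torch.save(ckpt, "scprint.ckpt")
```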
```diff
@@ -107,7 +129,6 @@
     keep_all_cls_pred=False,
     output_expression="none",
     save_every=30_000,
-    precision=precision,
     dtype=dtype,
 )
 embedded, _ = embedder(model, adata, cache=False)
```
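Putting the device handling together, the inference path after this commit reduces to roughly the following. A sketch under the same assumptions as the script: `model` is the scPrint model loaded above, `adata` is the preprocessed AnnData, and any `Embedder` arguments not shown in this hunk are assumed to keep their defaults or the script's values.

```python
# Sketch of the post-commit inference path: dtype and device follow CUDA availability.
import torch
from scprint.tasks import Embedder

def embed(model, adata):
    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        model = model.to(torch.float32)  # weights may be stored in half precision
    model = model.to("cuda" if use_cuda else "cpu")
    embedder = Embedder(
        keep_all_cls_pred=False,
        output_expression="none",
        save_every=30_000,
        dtype=torch.float16 if use_cuda else torch.float32,  # `precision` arg was dropped
    )
    embedded, _ = embedder(model, adata, cache=False)
    return embedded
```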
