Skip to content

Commit dd18949

Browse files
authored
Add argument to give model path to scGPT (#16)
* Add model path argument to scGPT
* Use cached scGPT model in benchmark workflow
* Make scGPT inherit from base method
* Swap scGPT model argument names
1 parent 52ccedb commit dd18949

File tree

3 files changed

+62
-17
lines changed

3 files changed

+62
-17
lines changed

src/methods/scgpt/config.vsh.yaml

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__merge__: ../../api/comp_method.yaml
1+
__merge__: ../../api/base_method.yaml
22

33
name: scgpt
44
label: scGPT
@@ -24,11 +24,18 @@ info:
2424
model: "scGPT_CP"
2525

2626
arguments:
27-
- name: --model
27+
- name: --model_name
2828
type: string
29-
description: String giving the scGPT model to use
29+
description: String giving the name of the scGPT model to use
3030
choices: ["scGPT_human", "scGPT_CP"]
3131
default: "scGPT_human"
32+
- name: --model
33+
type: file
34+
description: |
35+
Path to the directory containing the scGPT model specified by model_name
36+
or a .zip/.tar.gz archive to extract. If not given the model will be
37+
downloaded.
38+
required: false
3239
- name: --n_hvg
3340
type: integer
3441
default: 3000

src/methods/scgpt/script.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import os
12
import sys
3+
import tarfile
24
import tempfile
5+
import zipfile
36

47
import anndata as ad
58
import gdown
@@ -12,6 +15,7 @@
1215
par = {
1316
"input": "resources_test/.../input.h5ad",
1417
"output": "output.h5ad",
18+
"model_name": "scGPT_human",
1519
"model": "scGPT_human",
1620
"n_hvg": 3000,
1721
}
@@ -43,23 +47,54 @@
4347

4448
print(adata, flush=True)
4549

46-
print(f"\n>>> Downloading '{par['model']}' model...", flush=True)
47-
model_drive_ids = {
48-
"scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
49-
"scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
50-
}
51-
drive_path = f"https://drive.google.com/drive/folders/{model_drive_ids[par['model']]}"
52-
model_dir = tempfile.TemporaryDirectory()
53-
print(f"Downloading from '{drive_path}'", flush=True)
54-
gdown.download_folder(drive_path, output=model_dir.name, quiet=True)
55-
print(f"Model directory: '{model_dir.name}'", flush=True)
50+
if par["model"] is None:
51+
print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
52+
model_drive_ids = {
53+
"scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
54+
"scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
55+
}
56+
drive_path = (
57+
f"https://drive.google.com/drive/folders/{model_drive_ids[par['model_name']]}"
58+
)
59+
model_temp = tempfile.TemporaryDirectory()
60+
model_dir = model_temp.name
61+
print(f"Downloading from '{drive_path}'", flush=True)
62+
gdown.download_folder(drive_path, output=model_dir, quiet=True)
63+
else:
64+
if os.path.isdir(par["model"]):
65+
print(f"\n>>> Using model directory...", flush=True)
66+
model_temp = None
67+
model_dir = par["model"]
68+
else:
69+
model_temp = tempfile.TemporaryDirectory()
70+
model_dir = model_temp.name
71+
72+
if zipfile.is_zipfile(par["model"]):
73+
print(f"\n>>> Extracting model from .zip...", flush=True)
74+
print(f".zip path: '{par['model']}'", flush=True)
75+
with zipfile.ZipFile(par["model"], "r") as zip_file:
76+
zip_file.extractall(model_dir)
77+
elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(
78+
".tar.gz"
79+
):
80+
print(f"\n>>> Extracting model from .tar.gz...", flush=True)
81+
print(f".tar.gz path: '{par['model']}'", flush=True)
82+
with tarfile.open(par["model"], "r:gz") as tar_file:
83+
tar_file.extractall(model_dir)
84+
model_dir = os.path.join(model_dir, os.listdir(model_dir)[0])
85+
else:
86+
raise ValueError(
87+
f"The 'model' argument should be a directory a .zip file or a .tar.gz file"
88+
)
89+
90+
print(f"Model directory: '{model_dir}'", flush=True)
5691

5792
print("\n>>> Embedding data...", flush=True)
5893
device = "cuda" if torch.cuda.is_available() else "cpu"
5994
print(f"Device: '{device}'", flush=True)
6095
embedded = scgpt.tasks.embed_data(
6196
adata,
62-
model_dir.name,
97+
model_dir,
6398
gene_col="feature_name",
6499
batch_size=64,
65100
use_fast_transformer=False, # Disable fast-attn as not installed
@@ -86,7 +121,8 @@
86121
print(f"Output H5AD file: '{par['output']}'", flush=True)
87122
output.write_h5ad(par["output"], compression="gzip")
88123

89-
print("\n>>> Cleaning up temporary directories...", flush=True)
90-
model_dir.cleanup()
124+
if model_temp is not None:
125+
print("\n>>> Cleaning up temporary directories...", flush=True)
126+
model_temp.cleanup()
91127

92128
print("\n>>> Done!", flush=True)

src/workflows/run_benchmark/main.nf

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ methods = [
2929
scalex,
3030
scanorama,
3131
scanvi,
32-
scgpt,
32+
scgpt.run(
33+
args: [model_path: file("s3://openproblems-work/cache/scGPT_human.zip")]
34+
),
3335
scimilarity.run(
3436
args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")]
3537
),

0 commit comments

Comments (0)