|
| 1 | +import os |
1 | 2 | import sys |
| 3 | +import tarfile |
2 | 4 | import tempfile |
| 5 | +import zipfile |
3 | 6 |
|
4 | 7 | import anndata as ad |
5 | 8 | import gdown |
|
# Placeholder parameters for interactive testing/debugging; in a real run
# these are injected by the framework before this block.
par = {
    "input": "resources_test/.../input.h5ad",
    "output": "output.h5ad",
    # Name of the published scGPT checkpoint to download when no local
    # model is supplied.
    "model_name": "scGPT_human",
    # Local model directory, .zip or .tar.gz path. None means "download
    # the checkpoint named by `model_name` from Google Drive" — a model
    # *name* here would be rejected by the path/archive checks below.
    "model": None,
    "n_hvg": 3000,
}
|
43 | 47 |
|
# Log the loaded AnnData object so its dimensions/layers appear in the run log.
print(adata, flush=True)
45 | 49 |
|
# Resolve the scGPT model directory. Three cases:
#   1. par["model"] is None      -> download the checkpoint named by
#                                   par["model_name"] from Google Drive
#   2. par["model"] is a dir     -> use it directly
#   3. par["model"] is an archive (.zip / .tar.gz) -> extract to a temp dir
# `model_temp` keeps the TemporaryDirectory handle (if any) alive so it can
# be cleaned up after embedding; `model_dir` is the path handed to scGPT.
if par["model"] is None:
    print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
    # Google Drive folder IDs of the published scGPT checkpoints.
    model_drive_ids = {
        "scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
        "scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
    }
    drive_path = (
        f"https://drive.google.com/drive/folders/{model_drive_ids[par['model_name']]}"
    )
    model_temp = tempfile.TemporaryDirectory()
    model_dir = model_temp.name
    print(f"Downloading from '{drive_path}'", flush=True)
    gdown.download_folder(drive_path, output=model_dir, quiet=True)
elif os.path.isdir(par["model"]):
    print("\n>>> Using model directory...", flush=True)
    model_temp = None  # nothing temporary to clean up later
    model_dir = par["model"]
elif zipfile.is_zipfile(par["model"]):
    print("\n>>> Extracting model from .zip...", flush=True)
    print(f".zip path: '{par['model']}'", flush=True)
    model_temp = tempfile.TemporaryDirectory()
    model_dir = model_temp.name
    with zipfile.ZipFile(par["model"], "r") as zip_file:
        zip_file.extractall(model_dir)
elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(".tar.gz"):
    print("\n>>> Extracting model from .tar.gz...", flush=True)
    print(f".tar.gz path: '{par['model']}'", flush=True)
    model_temp = tempfile.TemporaryDirectory()
    model_dir = model_temp.name
    # NOTE(review): extractall on a user-supplied archive; consider
    # tarfile's filter="data" (Python 3.12+) to guard against path
    # traversal if the archive is not trusted.
    with tarfile.open(par["model"], "r:gz") as tar_file:
        tar_file.extractall(model_dir)
    # Presumably the .tar.gz wraps everything in a single top-level
    # folder, so descend into it — TODO confirm against the published
    # checkpoint archives.
    model_dir = os.path.join(model_dir, os.listdir(model_dir)[0])
else:
    raise ValueError(
        "The 'model' argument should be a directory, a .zip file or a .tar.gz file"
    )

print(f"Model directory: '{model_dir}'", flush=True)
56 | 91 |
|
print("\n>>> Embedding data...", flush=True)
# Run on GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: '{device}'", flush=True)
60 | 95 | embedded = scgpt.tasks.embed_data( |
61 | 96 | adata, |
62 | | - model_dir.name, |
| 97 | + model_dir, |
63 | 98 | gene_col="feature_name", |
64 | 99 | batch_size=64, |
65 | 100 | use_fast_transformer=False, # Disable fast-attn as not installed |
|
# Persist the embedded AnnData to the requested output file.
output_path = par["output"]
print(f"Output H5AD file: '{output_path}'", flush=True)
output.write_h5ad(output_path, compression="gzip")
88 | 123 |
|
# Drop the temporary model directory. It only exists when the model was
# downloaded or extracted from an archive; a user-supplied directory
# (model_temp is None) is left untouched.
needs_cleanup = model_temp is not None
if needs_cleanup:
    print("\n>>> Cleaning up temporary directories...", flush=True)
    model_temp.cleanup()

print("\n>>> Done!", flush=True)
0 commit comments