|
6 | 6 | import torch |
7 | 7 | from huggingface_hub import hf_hub_download |
8 | 8 | from scdataloader import Preprocessor |
| 9 | +from scdataloader.utils import load_genes |
9 | 10 | from scprint import scPrint |
10 | 11 | from scprint.tasks import Embedder |
11 | 12 |
|
|
63 | 64 | repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt" |
64 | 65 | ) |
65 | 66 |
|
| 67 | +print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) |
| 68 | + |
66 | 69 | if torch.cuda.is_available(): |
67 | 70 | print("CUDA is available, using GPU", flush=True) |
68 | | - precision = "16" |
69 | | - dtype = torch.float16 |
70 | 71 | transformer = "flash" |
71 | 72 | else: |
72 | 73 | print("CUDA is not available, using CPU", flush=True) |
73 | | - precision = "32" |
74 | | - dtype = torch.float32 |
75 | 74 | transformer = "normal" |
76 | 75 |
|
77 | | -print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) |
| 76 | +try: |
| 77 | + m = torch.load(model_checkpoint_file) |
| 78 | +# if not use this instead since the model weights are by default mapped to GPU types |
| 79 | +except RuntimeError: |
| 80 | + m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) |
78 | 81 |
|
79 | | -m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) |
| 82 | +# both are for compatibility issues with different versions of the pretrained model, so we need to load it with the correct transformer |
| 83 | +if "prenorm" in m["hyper_parameters"]: |
| 84 | + m["hyper_parameters"].pop("prenorm") |
| 85 | + torch.save(m, model_checkpoint_file) |
80 | 86 | if "label_counts" in m["hyper_parameters"]: |
| 87 | + # you need to set precpt_gene_emb=None otherwise the model will look for its precomputed gene embeddings files although they were already converted into model weights, so you don't need this file for a pretrained model |
81 | 88 | model = scPrint.load_from_checkpoint( |
82 | 89 | model_checkpoint_file, |
83 | | - transformer=transformer, # Don't use this for GPUs with flashattention |
84 | 90 | precpt_gene_emb=None, |
85 | 91 | classes=m["hyper_parameters"]["label_counts"], |
| 92 | + transformer=transformer, |
86 | 93 | ) |
87 | 94 | else: |
88 | 95 | model = scPrint.load_from_checkpoint( |
89 | | - model_checkpoint_file, |
90 | | - transformer=transformer, # Don't use this for GPUs with flashattention |
91 | | - precpt_gene_emb=None, |
| 96 | + model_checkpoint_file, precpt_gene_emb=None, transformer=transformer |
92 | 97 | ) |
93 | 98 | del m |
| 99 | +# this might happen if you have a model that was trained with a different set of genes than the one you are using in the ontology (e.g. newer ontologies), While having genes in the onlogy not in the model is fine. the opposite is not, so we need to remove the genes that are in the model but not in the ontology |
| 100 | +missing = set(model.genes) - set(load_genes(model.organisms).index) |
| 101 | +if len(missing) > 0: |
| 102 | + print( |
| 103 | + "Warning: some genes missmatch exist between model and ontology: solving...", |
| 104 | + ) |
| 105 | + model._rm_genes(missing) |
| 106 | + |
| 107 | +# again if not on GPU you need to convert the model to float32 |
| 108 | +if not torch.cuda.is_available(): |
| 109 | + model = model.to(torch.float32) |
| 110 | + |
| 111 | +# you can perform your inference on float16 if you have a GPU, otherwise use float64 |
| 112 | +dtype = torch.float16 if torch.cuda.is_available() else torch.float32 |
| 113 | + |
| 114 | +# the models are often loaded with some parts still displayed as "cuda" and some as "cpu", so we need to make sure that the model is fully on the right device |
| 115 | +model = model.to("cuda" if torch.cuda.is_available() else "cpu") |
94 | 116 |
|
95 | 117 | print("\n>>> Embedding data...", flush=True) |
96 | 118 | n_cores = min(len(os.sched_getaffinity(0)), 24) |
|
107 | 129 | keep_all_cls_pred=False, |
108 | 130 | output_expression="none", |
109 | 131 | save_every=30_000, |
110 | | - precision=precision, |
111 | 132 | dtype=dtype, |
112 | 133 | ) |
113 | 134 | embedded, _ = embedder(model, adata, cache=False) |
|
0 commit comments