From ffba95460cbea7fe2746d469cffe46e8f641d6f4 Mon Sep 17 00:00:00 2001 From: jkobject Date: Wed, 5 Mar 2025 17:43:08 +0100 Subject: [PATCH 01/13] better loading of the model and new model --- src/methods/scprint/config.vsh.yaml | 19 ++++----- src/methods/scprint/script.py | 63 ++++++++++++++++++----------- 2 files changed, 47 insertions(+), 35 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 05746c99..f526da7a 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -38,7 +38,7 @@ info: scprint_large: model_name: "large" scprint_medium: - model_name: "medium" + model_name: "v2-medium" scprint_small: model_name: "small" test_setup: @@ -51,8 +51,8 @@ arguments: - name: "--model_name" type: "string" description: Which model to use. Not used if --model is provided. - choices: ["large", "medium", "small"] - default: "large" + choices: ["large", "v2-medium", "small"] + default: "v2-medium" - name: --model type: file description: Path to the scPRINT model. @@ -60,11 +60,11 @@ arguments: - name: --batch_size type: integer description: The size of the batches to be used in the DataLoader. - default: 64 + default: 32 - name: --max_len type: integer description: The maximum length of the gene sequence. - default: 2000 + default: 4000 resources: - type: python_script @@ -77,17 +77,14 @@ engines: setup: - type: python pip: - - huggingface_hub - # Can be unpinned after https://github.com/cantinilab/scPRINT/issues/14 is resolved - - scprint==1.6.2 - - scdataloader==1.6.4 + - scprint - type: docker run: | lamin init --storage ./main --name main --schema bionty - - type: python - script: import bionty as bt; bt.core.sync_all_sources_to_latest() - type: docker run: lamin load anonymous/main + - type: python + script: import bionty as bt; bt.core.sync_all_sources_to_latest() - type: python script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index cf58b7ef..a80d1396 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -14,7 +14,7 @@ "input_train": "resources_test/task_label_projection/cxg_immune_cell_atlas/train.h5ad", "input_test": "resources_test/task_label_projection/cxg_immune_cell_atlas/test.h5ad", "output": "output.h5ad", - "model_name": "large", + "model_name": "v2-medium", "model": None, } meta = {"name": "scprint"} @@ -29,8 +29,8 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" print("\n>>> Reading input data...", flush=True) -input_train = ad.read_h5ad(par['input_train']) -input_test = ad.read_h5ad(par['input_test']) +input_train = ad.read_h5ad(par["input_train"]) +input_test = ad.read_h5ad(par["input_test"]) input_test_uns = input_test.uns.copy() print("\n>>> Preprocessing input data...", flush=True) @@ -44,7 +44,7 @@ else: exit_non_applicable( f"scPRINT can only be used with human data " - f"(dataset_organism == \"{input_train.uns['dataset_organism']}\")" + f'(dataset_organism == "{input_train.uns["dataset_organism"]}")' ) # move data @@ -59,7 +59,9 @@ # applying preprocessor preprocessor = Preprocessor( # Lower this threshold for test datasets - min_valid_genes_id=min(0.9 * input_train.n_vars, 10000), # 90% of features up to 10,000 + min_valid_genes_id=min( + 0.9 * input_train.n_vars, 10000 + ), # 90% of features up to 10,000 # Turn off cell filtering to return results for all cells filter_cell_by_counts=False, min_nnz_genes=False, @@ -79,11 +81,22 
@@ ) print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) -model = scprint.scPrint.load_from_checkpoint( - model_checkpoint_file, - transformer="normal", # Don't use this for GPUs with flashattention - precpt_gene_emb=None, -) + +m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) +if "label_counts" in m["hyper_parameters"]: + model = scPrint.load_from_checkpoint( + model_checkpoint_file, + transformer=transformer, # Don't use this for GPUs with flashattention + precpt_gene_emb=None, + classes=m["hyper_parameters"]["label_counts"], + ) +else: + model = scPrint.load_from_checkpoint( + model_checkpoint_file, + transformer=transformer, # Don't use this for GPUs with flashattention + precpt_gene_emb=None, + ) +del m print("\n>>> Embedding train data...", flush=True) if torch.cuda.is_available(): @@ -110,7 +123,7 @@ dtype=dtype, pred_embedding=["cell_type_ontology_term_id"], keep_all_cls_pred=False, - output_expression="none" + output_expression="none", ) embedded, _ = embedder(model, input_train, cache=False) @@ -156,7 +169,7 @@ # we remove any that have been matched and repeat until all predicted labels are # matched to a dataset label. print("\n---- INFERRED MATCHES ----", flush=True) -print(f"{'PREDICTED' : <40}{'LABEL' : <40}", flush=True) +print(f"{'PREDICTED': <40}{'LABEL': <40}", flush=True) while not all(pred in matches for pred in predicted_levels): # Get predicted labels that have not yet been matched not_matched = [pred for pred in predicted_levels if pred not in matches.keys()] @@ -171,11 +184,11 @@ label_level = label_levels[label] matches[predicted_level] = label_level - if (len(predicted_level) > 39): - predicted_level = predicted_level[:36] + '...' + if len(predicted_level) > 39: + predicted_level = predicted_level[:36] + "..." - if (len(label_level) > 39): - label_level = label_level[:36] + '...' + if len(label_level) > 39: + label_level = label_level[:36] + "..." 
print(f"{predicted_level: <40}{label_level: <40}", flush=True) @@ -193,22 +206,24 @@ dtype=dtype, pred_embedding=["cell_type_ontology_term_id"], keep_all_cls_pred=False, - output_expression="none" + output_expression="none", ) embedded_test, _ = embedder(model, input_test, cache=False) print("\n>>> Converting predictions to labels...", flush=True) -input_test.obs["label_pred"] = embedded_test.obs["conv_pred_cell_type_ontology_term_id"].values +input_test.obs["label_pred"] = embedded_test.obs[ + "conv_pred_cell_type_ontology_term_id" +].values input_test.obs = input_test.obs.replace(dict(label_pred=matches)) print("\n>>> Storing output...", flush=True) output = ad.AnnData( - obs=input_test.obs[["label_pred"]], - uns={ - 'method_id': meta['name'], - 'dataset_id': input_test_uns['dataset_id'], - 'normalization_id': input_test_uns['normalization_id'] - } + obs=input_test.obs[["label_pred"]], + uns={ + "method_id": meta["name"], + "dataset_id": input_test_uns["dataset_id"], + "normalization_id": input_test_uns["normalization_id"], + }, ) print("\n>>> Writing output AnnData to file...", flush=True) From 9c212cc70db47d64587ba7fa5241eb71a0ce1a07 Mon Sep 17 00:00:00 2001 From: jkobject Date: Mon, 10 Mar 2025 12:11:22 +0100 Subject: [PATCH 02/13] new solution to matching problem --- src/methods/scprint/config.vsh.yaml | 6 ++- src/methods/scprint/script.py | 78 ++++++++++++++++------------- 2 files changed, 47 insertions(+), 37 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index f526da7a..2a89fd21 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -44,7 +44,7 @@ info: test_setup: run: model_name: small - batch_size: 64 + batch_size: 16 max_len: 100 arguments: @@ -77,7 +77,8 @@ engines: setup: - type: python pip: - - scprint + - scprint==2.2.1 + - gseapy==1.1.2 - type: docker run: | lamin init --storage ./main --name main --schema bionty @@ -90,6 +91,7 @@ engines: runners: - type: executable + docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index a80d1396..b6c150d9 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -6,6 +6,7 @@ import os import sys import numpy as np +from scprint import scPrint from scipy.spatial import distance from scipy.optimize import linear_sum_assignment @@ -80,6 +81,17 @@ repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt" ) +if torch.cuda.is_available(): + print("CUDA is available, using GPU", flush=True) + precision = "16" + dtype = torch.float16 + transformer = "flash" +else: + print("CUDA is not available, using CPU", flush=True) + precision = "32" + dtype = torch.float32 + transformer = "normal" + print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) @@ -98,25 +110,15 @@ ) del m -print("\n>>> Embedding train data...", flush=True) -if torch.cuda.is_available(): - print("CUDA is available, using GPU", flush=True) - precision = "16" - dtype = torch.float16 -else: - print("CUDA is not available, using CPU", flush=True) - precision = "32" - dtype = torch.float32 - -n_cores_available = len(os.sched_getaffinity(0)) +n_cores = max(16, len(os.sched_getaffinity(0))) -print(f"Using {n_cores_available} worker cores") +print(f"Using {n_cores} worker cores") embedder = scprint.tasks.Embedder( batch_size=par["batch_size"], 
how="random expr", max_len=par["max_len"], add_zero_genes=0, - num_workers=n_cores_available, + num_workers=n_cores, doclass=True, doplot=False, precision=precision, @@ -170,27 +172,33 @@ # matched to a dataset label. print("\n---- INFERRED MATCHES ----", flush=True) print(f"{'PREDICTED': <40}{'LABEL': <40}", flush=True) -while not all(pred in matches for pred in predicted_levels): - # Get predicted labels that have not yet been matched - not_matched = [pred for pred in predicted_levels if pred not in matches.keys()] - not_matched_idx = [predicted_levels.index(pred) for pred in not_matched] - - # Get assignments for currently unmatched predicted labels - assignments = linear_sum_assignment(jaccard[:, not_matched_idx]) - - # Store any new matches - for label, pred in zip(assignments[0], assignments[1]): - predicted_level = not_matched[pred] - label_level = label_levels[label] - matches[predicted_level] = label_level - - if len(predicted_level) > 39: - predicted_level = predicted_level[:36] + "..." - - if len(label_level) > 39: - label_level = label_level[:36] + "..." - - print(f"{predicted_level: <40}{label_level: <40}", flush=True) +for i, pred in enumerate(predicted_levels): + matches[pred] = label_levels[np.argmin(jaccard[:, i])] + print(f"{pred: <40}{matches[pred]: <40}", flush=True) + +# previous version + +# while not all(pred in matches for pred in predicted_levels): +# # Get predicted labels that have not yet been matched +# not_matched = [pred for pred in predicted_levels if pred not in matches.keys()] +# not_matched_idx = [predicted_levels.index(pred) for pred in not_matched] +# +# # Get assignments for currently unmatched predicted labels +# assignments = linear_sum_assignment(jaccard[:, not_matched_idx]) +# +# # Store any new matches +# for label, pred in zip(assignments[0], assignments[1]): +# predicted_level = not_matched[pred] +# label_level = label_levels[label] +# matches[predicted_level] = label_level +# +# if len(predicted_level) > 39: +# predicted_level = predicted_level[:36] + "..." +# +# if len(label_level) > 39: +# label_level = label_level[:36] + "..." 
+# +# print(f"{predicted_level: <40}{label_level: <40}", flush=True) print("\n>>> Embedding test data...", flush=True) @@ -199,7 +207,7 @@ how="random expr", max_len=par["max_len"], add_zero_genes=0, - num_workers=n_cores_available, + num_workers=n_cores, doclass=True, doplot=False, precision=precision, From 50c20e2c41453e9c4dbc6e8c1067611846c249be Mon Sep 17 00:00:00 2001 From: jkobject Date: Mon, 10 Mar 2025 12:11:44 +0100 Subject: [PATCH 03/13] removing gpu --- src/methods/scprint/config.vsh.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 2a89fd21..a464d8a6 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -91,7 +91,6 @@ engines: runners: - type: executable - docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] From 55f7528944aa2edce359f9ceda3f9161e6d12788 Mon Sep 17 00:00:00 2001 From: jkobject Date: Thu, 13 Mar 2025 19:02:59 +0100 Subject: [PATCH 04/13] proposing two versions --- src/methods/scprint/config.vsh.yaml | 11 ++++- src/methods/scprint/script.py | 62 ++++++++++++++++------------- 2 files changed, 44 insertions(+), 29 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index a464d8a6..c1ac38fb 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -65,6 +65,13 @@ arguments: type: integer description: The maximum length of the gene sequence. default: 4000 + - name: --infer_matches + type: string + description: + The method to use to infer the matches between the predicted and + true labels. + choices: ["direct", "linear_sum_assignment"] + default: "direct" resources: - type: python_script @@ -77,8 +84,9 @@ engines: setup: - type: python pip: - - scprint==2.2.1 + - git+https://github.com/cantinilab/scPRINT.git@d8cc270b099c8d5dacf6913acc26f2b696685b2b - gseapy==1.1.2 + - git+https://github.com/jkobject/scDataLoader.git@0f9e1858c8a4c6b0239ceb00e762d52032d745e7 - type: docker run: | lamin init --storage ./main --name main --schema bionty @@ -91,6 +99,7 @@ engines: runners: - type: executable + docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index c9257f1a..62c7b2d8 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -17,6 +17,7 @@ "output": "output.h5ad", "model_name": "v2-medium", "model": None, + "infer_matches": "direct", } meta = {"name": "scprint"} ## VIASH END @@ -85,7 +86,7 @@ print("CUDA is available, using GPU", flush=True) precision = "16" dtype = torch.float16 - transformer="flash" + transformer = "flash" else: print("CUDA is not available, using CPU", flush=True) precision = "32" @@ -170,37 +171,42 @@ # the Jaccard distances. This algorithm may not match all predicted labels so # we remove any that have been matched and repeat until all predicted labels are # matched to a dataset label. 
+ print("\n---- INFERRED MATCHES ----", flush=True) print(f"{'PREDICTED': <40}{'LABEL': <40}", flush=True) -for i, pred in enumerate(predicted_levels): - matches[pred] = label_levels[np.argmin(jaccard[:, i])] - print(f"{pred: <40}{matches[pred]: <40}", flush=True) - -# previous version - -# while not all(pred in matches for pred in predicted_levels): -# # Get predicted labels that have not yet been matched -# not_matched = [pred for pred in predicted_levels if pred not in matches.keys()] -# not_matched_idx = [predicted_levels.index(pred) for pred in not_matched] -# -# # Get assignments for currently unmatched predicted labels -# assignments = linear_sum_assignment(jaccard[:, not_matched_idx]) -# -# # Store any new matches -# for label, pred in zip(assignments[0], assignments[1]): -# predicted_level = not_matched[pred] -# label_level = label_levels[label] -# matches[predicted_level] = label_level -# -# if len(predicted_level) > 39: -# predicted_level = predicted_level[:36] + "..." -# -# if len(label_level) > 39: -# label_level = label_level[:36] + "..." -# -# print(f"{predicted_level: <40}{label_level: <40}", flush=True) +if par["infer_matches"] == "direct": + # other version + for i, pred in enumerate(predicted_levels): + matches[pred] = label_levels[np.argmin(jaccard[:, i])] + print(f"{pred: <40}{matches[pred]: <40}", flush=True) + +elif par["infer_matches"] == "linear_sum_assignment": + # previous version + while not all(pred in matches for pred in predicted_levels): + # Get predicted labels that have not yet been matched + not_matched = [pred for pred in predicted_levels if pred not in matches.keys()] + not_matched_idx = [predicted_levels.index(pred) for pred in not_matched] + # Get assignments for currently unmatched predicted labels + assignments = linear_sum_assignment(jaccard[:, not_matched_idx]) + # Store any new matches + for label, pred in zip(assignments[0], assignments[1]): + predicted_level = not_matched[pred] + label_level = label_levels[label] + matches[predicted_level] = label_level + if len(predicted_level) > 39: + predicted_level = predicted_level[:36] + "..." + if len(label_level) > 39: + label_level = label_level[:36] + "..." 
+ print(f"{predicted_level: <40}{label_level: <40}", flush=True) +else: + raise ValueError(f"Invalid value for infer_matches: {par['infer_matches']}") +print("\n---- UNMATCHED TRUE LABELS ----", flush=True) +for l in lower_label_levels: + if l not in [m.lower() for m in matches.values()]: + print(l) + print("\n>>> Embedding test data...", flush=True) embedder = scprint.tasks.Embedder( batch_size=par["batch_size"], From ae3626fd38fb314225631a10c2593b44fe0859c1 Mon Sep 17 00:00:00 2001 From: jkobject Date: Fri, 14 Mar 2025 09:59:41 +0100 Subject: [PATCH 05/13] removing flag --- src/methods/scprint/config.vsh.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index c1ac38fb..838b034f 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -99,7 +99,6 @@ engines: runners: - type: executable - docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] From d1372bce5317e97467a04a86875a22808098cbc8 Mon Sep 17 00:00:00 2001 From: jkobject Date: Fri, 14 Mar 2025 17:08:07 +0100 Subject: [PATCH 06/13] default to linear sum --- src/methods/scprint/config.vsh.yaml | 4 ++-- src/methods/scprint/script.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 838b034f..aabfc943 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -71,7 +71,7 @@ arguments: The method to use to infer the matches between the predicted and true labels. choices: ["direct", "linear_sum_assignment"] - default: "direct" + default: "linear_sum_assignment" resources: - type: python_script @@ -86,7 +86,7 @@ engines: pip: - git+https://github.com/cantinilab/scPRINT.git@d8cc270b099c8d5dacf6913acc26f2b696685b2b - gseapy==1.1.2 - - git+https://github.com/jkobject/scDataLoader.git@0f9e1858c8a4c6b0239ceb00e762d52032d745e7 + - git+https://github.com/jkobject/scDataLoader.git@c67c24a2e5c62399912be39169aae76e29e108aa - type: docker run: | lamin init --storage ./main --name main --schema bionty diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 62c7b2d8..bc139a35 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -17,7 +17,7 @@ "output": "output.h5ad", "model_name": "v2-medium", "model": None, - "infer_matches": "direct", + "infer_matches": "linear_sum_assignment", } meta = {"name": "scprint"} ## VIASH END From d55892e26ca11de80eb02c1eb2f6a06e9dba2575 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Tue, 18 Mar 2025 10:20:23 +0100 Subject: [PATCH 07/13] Style and lint scPRINT script --- src/methods/scprint/script.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index bc139a35..7881a900 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -1,14 +1,15 @@ -import anndata as ad -from scdataloader import Preprocessor -from huggingface_hub import hf_hub_download -import scprint -import torch import os import sys + +import anndata as ad import numpy as np -from scprint import scPrint -from scipy.spatial import distance +import scprint +import torch +from huggingface_hub import hf_hub_download +from scdataloader import Preprocessor from scipy.optimize import linear_sum_assignment +from scipy.spatial import distance +from scprint import scPrint ## VIASH 
START par = { @@ -146,7 +147,7 @@ # If there are any predicted labels that exactly match a dataset label we use them directly matches = {} -lower_label_levels = [l.lower() for l in label_levels] +lower_label_levels = [lbl.lower() for lbl in label_levels] print("---- EXACT MATCHES ----", flush=True) for pred in predicted_levels: if pred.lower() in lower_label_levels: @@ -161,7 +162,7 @@ combos = [(label, pred) for label in label_levels for pred in predicted_levels] for label, pred in combos: - labels_bin = [1 if l == label else 0 for l in label_values] + labels_bin = [1 if lbl == label else 0 for lbl in label_values] predicted_bin = [1 if p == pred else 0 for p in predicted_values] label_idx = label_levels.index(label) predicted_idx = predicted_levels.index(pred) @@ -203,9 +204,9 @@ else: raise ValueError(f"Invalid value for infer_matches: {par['infer_matches']}") print("\n---- UNMATCHED TRUE LABELS ----", flush=True) -for l in lower_label_levels: - if l not in [m.lower() for m in matches.values()]: - print(l) +for lbl in lower_label_levels: + if lbl not in [m.lower() for m in matches.values()]: + print(lbl, flush=True) print("\n>>> Embedding test data...", flush=True) embedder = scprint.tasks.Embedder( From af50d0006000cb3279a26f5108e4f2e4c7ebdea5 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Tue, 18 Mar 2025 10:45:42 +0100 Subject: [PATCH 08/13] Update CHANGELOG --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 06d1498e..8f562f93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +# task_label_projection devel + +* Update scPRINT to better handle large datasets, including a new default model (PR #20) + # task_label_projection 2.0.0 A major update to the OpenProblems framework, switching from a Python-based framework to a Viash + Nextflow-based framework. This update features the same concepts as the previous version, but with a new implementation that is more flexible, scalable, and maintainable. 
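Note on the label-matching logic touched by PATCHES 02, 04, 06 and 07 above: the script matches scPRINT's predicted cell-type labels to the dataset's labels by Jaccard distance, either greedily per predicted label ("direct") or one-to-one via the Hungarian algorithm ("linear_sum_assignment"). Below is a minimal, self-contained sketch of both strategies; the toy label values are illustrative only and not taken from the benchmark data.

import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial import distance

# Illustrative per-cell values; in the script these come from the training
# labels and the model's conv_pred_cell_type_ontology_term_id predictions
label_values = ["T cell", "T cell", "B cell", "NK cell"]
predicted_values = ["t cell", "b cell", "b cell", "nk cell"]

label_levels = sorted(set(label_values))
predicted_levels = sorted(set(predicted_values))

# Jaccard distance between binary membership vectors: one row per dataset
# label, one column per predicted label (as in the script's `jaccard` matrix)
jaccard = np.zeros((len(label_levels), len(predicted_levels)))
for i, label in enumerate(label_levels):
    labels_bin = [1 if v == label else 0 for v in label_values]
    for j, pred in enumerate(predicted_levels):
        predicted_bin = [1 if v == pred else 0 for v in predicted_values]
        jaccard[i, j] = distance.jaccard(labels_bin, predicted_bin)

# "direct": each predicted label takes its closest dataset label
# (several predicted labels may map to the same dataset label)
direct = {
    pred: label_levels[np.argmin(jaccard[:, j])]
    for j, pred in enumerate(predicted_levels)
}

# "linear_sum_assignment": one-to-one matching that minimises the total
# distance; the script repeats this on still-unmatched columns until every
# predicted label has a match
rows, cols = linear_sum_assignment(jaccard)
assigned = {predicted_levels[j]: label_levels[i] for i, j in zip(rows, cols)}

The trade-off: "direct" tolerates a predicted vocabulary larger than the dataset's (many-to-one), while the assignment variant avoids collapsing distinct predictions onto the same label, which is why the patches switch the default between the two.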
From 4538bab7e303493ec39cc1795b1c441d1f2d764b Mon Sep 17 00:00:00 2001 From: jkobject-tower Date: Tue, 12 Aug 2025 16:00:39 +0200 Subject: [PATCH 09/13] updating scprint version --- src/methods/scprint/config.vsh.yaml | 5 ++- src/methods/scprint/script.py | 48 +++++++++++++++++++++-------- 2 files changed, 37 insertions(+), 16 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index aabfc943..8cacdc4f 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -84,9 +84,7 @@ engines: setup: - type: python pip: - - git+https://github.com/cantinilab/scPRINT.git@d8cc270b099c8d5dacf6913acc26f2b696685b2b - - gseapy==1.1.2 - - git+https://github.com/jkobject/scDataLoader.git@c67c24a2e5c62399912be39169aae76e29e108aa + - scprint==2.3.5 - type: docker run: | lamin init --storage ./main --name main --schema bionty @@ -99,6 +97,7 @@ runners: - type: executable + docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index 7881a900..bfd6c810 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -7,6 +7,7 @@ import torch from huggingface_hub import hf_hub_download from scdataloader import Preprocessor +from scdataloader.utils import load_genes from scipy.optimize import linear_sum_assignment from scipy.spatial import distance from scprint import scPrint @@ -34,7 +35,9 @@ print("\n>>> Reading input data...", flush=True) input_train = ad.read_h5ad(par["input_train"]) input_test = ad.read_h5ad(par["input_test"]) + input_test_uns = input_test.uns.copy() +input_test_obs = input_test.obs.copy() print("\n>>> Preprocessing input data...", flush=True) # store organism ontology term id @@ -85,32 +88,52 @@ if torch.cuda.is_available(): print("CUDA is available, using GPU", flush=True) - precision = "16" - dtype = torch.float16 transformer = "flash" else: print("CUDA is not available, using CPU", flush=True) - precision = "32" - dtype = torch.float32 transformer = "normal" -print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) - -m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) +# make sure to check whether you have a GPU with flash attention or not (see README) +try: + m = torch.load(model_checkpoint_file) +# if not, use this instead, since the model weights are mapped to GPU types by default +except RuntimeError: + m = torch.load(model_checkpoint_file, map_location=torch.device("cpu")) + +# both of these handle compatibility issues with different versions of the pretrained model, so that it is loaded with the correct settings +if "prenorm" in m["hyper_parameters"]: + m["hyper_parameters"].pop("prenorm") + torch.save(m, model_checkpoint_file) if "label_counts" in m["hyper_parameters"]: + # set precpt_gene_emb=None, otherwise the model will look for its precomputed gene embedding files; these were already converted into model weights, so the file is not needed for a pretrained model model = scPrint.load_from_checkpoint( model_checkpoint_file, - transformer=transformer, # Don't use this for GPUs with flashattention precpt_gene_emb=None, classes=m["hyper_parameters"]["label_counts"], + transformer=transformer, ) else: model = scPrint.load_from_checkpoint( - model_checkpoint_file, - transformer=transformer, # Don't use this for GPUs with flashattention - precpt_gene_emb=None, + model_checkpoint_file,
precpt_gene_emb=None, transformer=transformer ) del m +# this can happen if the model was trained with a different set of genes than the ones in the ontology (e.g. newer ontologies). Having genes in the ontology that are not in the model is fine, but the opposite is not, so we remove the genes that are in the model but not in the ontology +missing = set(model.genes) - set(load_genes(model.organisms).index) +if len(missing) > 0: + print( + "Warning: some genes mismatch between model and ontology: solving...", + ) + model._rm_genes(missing) + +# again, if not on a GPU, convert the model to float32 +if not torch.cuda.is_available(): + model = model.to(torch.float32) + +# inference can run in float16 if you have a GPU, otherwise use float32 +dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + +# models are often loaded with some parts still placed on "cuda" and some on "cpu", so we make sure that the model is fully on the right device +model = model.to("cuda" if torch.cuda.is_available() else "cpu") n_cores = max(16, len(os.sched_getaffinity(0))) @@ -123,7 +146,6 @@ num_workers=n_cores, doclass=True, doplot=False, - precision=precision, dtype=dtype, pred_embedding=["cell_type_ontology_term_id"], keep_all_cls_pred=False, @@ -217,7 +239,6 @@ num_workers=n_cores, doclass=True, doplot=False, - precision=precision, dtype=dtype, pred_embedding=["cell_type_ontology_term_id"], keep_all_cls_pred=False, @@ -230,6 +251,7 @@ "conv_pred_cell_type_ontology_term_id" ].values input_test.obs = input_test.obs.replace(dict(label_pred=matches)) +input_test.obs.index = input_test_obs.index print("\n>>> Storing output...", flush=True) output = ad.AnnData( From f13b4a1abe8954448ff32b5001a4a1216f8e3f88 Mon Sep 17 00:00:00 2001 From: jkobject-tower Date: Tue, 12 Aug 2025 16:06:46 +0200 Subject: [PATCH 10/13] removing gpu device --- src/methods/scprint/config.vsh.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 48bc045a..9e6adac8 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -96,7 +96,7 @@ engines: runners: - type: executable - docker_run_args: --gpus all + #docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] From 62e72f45807973ef198d0057686e75576feebdf9 Mon Sep 17 00:00:00 2001 From: jkobject-tower Date: Mon, 29 Sep 2025 10:58:12 +0200 Subject: [PATCH 11/13] better scgpt installation (now uses flash attention) --- src/methods/scgpt_finetuned/config.vsh.yaml | 12 ++++++------ src/methods/scgpt_zeroshot/config.vsh.yaml | 21 ++++++++++++++------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/methods/scgpt_finetuned/config.vsh.yaml b/src/methods/scgpt_finetuned/config.vsh.yaml index c333c68e..cc287b9b 100644 --- a/src/methods/scgpt_finetuned/config.vsh.yaml +++ b/src/methods/scgpt_finetuned/config.vsh.yaml @@ -37,13 +37,13 @@ engines: - type: docker image: openproblems/base_pytorch_nvidia:1 setup: - - type: python - pypi: - - gdown - - scgpt # Install from PyPI to get dependencies - type: docker - # Force re-installing from GitHub to get bug fixes - run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git + run: | + git clone https://github.com/bowang-lab/scGPT && \ + pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url
https://download.pytorch.org/whl/cu121 && \ + pip install "flash-attn<1.0.5" --no-build-isolation && \ + pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \ + cd scGPT && pip install -e . --no-deps runners: - type: executable diff --git a/src/methods/scgpt_zeroshot/config.vsh.yaml b/src/methods/scgpt_zeroshot/config.vsh.yaml index f7e6353c..2f4d9e62 100644 --- a/src/methods/scgpt_zeroshot/config.vsh.yaml +++ b/src/methods/scgpt_zeroshot/config.vsh.yaml @@ -43,15 +43,22 @@ engines: - type: docker image: openproblems/base_pytorch_nvidia:1 setup: - - type: python - pypi: - - gdown - - scgpt # Install from PyPI to get dependencies - - faiss-cpu # TODO: Try installing faiss-gpu - type: docker - # Force re-installing from GitHub to get bug fixes run: | - pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git + git clone https://github.com/bowang-lab/scGPT && \ + pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \ + pip install "flash-attn<1.0.5" --no-build-isolation && \ + pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \ + cd scGPT && pip install -e . --no-deps + # - type: python + # pypi: + # - gdown + # - scgpt # Install from PyPI to get dependencies + # - faiss-cpu # TODO: Try installing faiss-gpu + # - type: docker + # # Force re-installing from GitHub to get bug fixes + # run: | + # pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git runners: - type: executable From 7c785a68cd7a83ad33ca7f106f6f85b9fb4a22b5 Mon Sep 17 00:00:00 2001 From: jkobject-tower Date: Mon, 13 Oct 2025 13:21:02 +0200 Subject: [PATCH 12/13] changing default parameters --- src/methods/scprint/config.vsh.yaml | 6 +++--- src/methods/scprint/script.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml index 9e6adac8..ca502815 100644 --- a/src/methods/scprint/config.vsh.yaml +++ b/src/methods/scprint/config.vsh.yaml @@ -64,14 +64,14 @@ arguments: - name: --max_len type: integer description: The maximum length of the gene sequence. - default: 4000 + default: 2000 - name: --infer_matches type: string description: The method to use to infer the matches between the predicted and true labels. 
choices: ["direct", "linear_sum_assignment"] - default: "linear_sum_assignment" + default: "direct" resources: - type: python_script @@ -96,7 +96,7 @@ engines: runners: - type: executable - #docker_run_args: --gpus all + # docker_run_args: --gpus all - type: nextflow directives: label: [hightime, highmem, midcpu, gpu, highsharedmem] diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py index bfd6c810..753643a5 100644 --- a/src/methods/scprint/script.py +++ b/src/methods/scprint/script.py @@ -19,7 +19,7 @@ "output": "output.h5ad", "model_name": "v2-medium", "model": None, - "infer_matches": "linear_sum_assignment", + "infer_matches": "direct", } meta = {"name": "scprint"} ## VIASH END From f6adff55b3ed5274a6812a08cc887ad66e9b4b35 Mon Sep 17 00:00:00 2001 From: jkobject-tower Date: Mon, 13 Oct 2025 13:40:02 +0200 Subject: [PATCH 13/13] removing scgpt --- src/methods/scgpt_finetuned/config.vsh.yaml | 12 ++++++------ src/methods/scgpt_zeroshot/config.vsh.yaml | 21 +++++++-------------- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/src/methods/scgpt_finetuned/config.vsh.yaml b/src/methods/scgpt_finetuned/config.vsh.yaml index cc287b9b..c333c68e 100644 --- a/src/methods/scgpt_finetuned/config.vsh.yaml +++ b/src/methods/scgpt_finetuned/config.vsh.yaml @@ -37,13 +37,13 @@ engines: - type: docker image: openproblems/base_pytorch_nvidia:1 setup: + - type: python + pypi: + - gdown + - scgpt # Install from PyPI to get dependencies - type: docker - run: | - git clone https://github.com/bowang-lab/scGPT && \ - pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \ - pip install "flash-attn<1.0.5" --no-build-isolation && \ - pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \ - cd scGPT && pip install -e . --no-deps + # Force re-installing from GitHub to get bug fixes + run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git runners: - type: executable diff --git a/src/methods/scgpt_zeroshot/config.vsh.yaml b/src/methods/scgpt_zeroshot/config.vsh.yaml index 2f4d9e62..f7e6353c 100644 --- a/src/methods/scgpt_zeroshot/config.vsh.yaml +++ b/src/methods/scgpt_zeroshot/config.vsh.yaml @@ -43,22 +43,15 @@ engines: - type: docker image: openproblems/base_pytorch_nvidia:1 setup: + - type: python + pypi: + - gdown + - scgpt # Install from PyPI to get dependencies + - faiss-cpu # TODO: Try installing faiss-gpu - type: docker + # Force re-installing from GitHub to get bug fixes run: | - git clone https://github.com/bowang-lab/scGPT && \ - pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \ - pip install "flash-attn<1.0.5" --no-build-isolation && \ - pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \ - cd scGPT && pip install -e . 
--no-deps - # - type: python - # pypi: - # - gdown - # - scgpt # Install from PyPI to get dependencies - # - faiss-cpu # TODO: Try installing faiss-gpu - # - type: docker - # # Force re-installing from GitHub to get bug fixes - # run: | - # pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git + pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git runners: - type: executable