adding scprint #17 (Closed · 4 commits)

Repository generated from openproblems-bio/task_template.
New file (+94 lines): Viash component configuration for the scPRINT method.

```yaml
__merge__: /src/api/comp_method.yaml

name: scprint
label: scPRINT
summary: scPRINT is a large transformer model built for the inference of gene networks
description: |
  scPRINT is a large transformer model built for the inference of gene networks
  (connections between genes explaining the cell's expression profile) from
  scRNAseq data.

  It uses novel encoding and decoding of the cell expression profile and new
  pre-training methodologies to learn a cell model.

  scPRINT can be used to perform the following analyses:

  - expression denoising: increase the resolution of your scRNAseq data
  - cell embedding: generate a low-dimensional representation of your dataset
  - label prediction: predict the cell type, disease, sequencer, sex, and
    ethnicity of your cells
  - gene network inference: generate a gene network from any cell or cell
    cluster in your scRNAseq dataset

references:
  doi:
    - 10.1101/2024.07.29.605556

links:
  documentation: https://cantinilab.github.io/scPRINT/
  repository: https://github.com/cantinilab/scPRINT

info:
  preferred_normalization: counts
  method_types: [embedding]
  variants:
    scprint_large:
      model_name: "large"
    scprint_medium:
      model_name: "v2-medium"
    scprint_small:
      model_name: "small"
  test_setup:
    run:
      model_name: small
      batch_size: 16
      max_len: 100

arguments:
  - name: "--model_name"
    type: "string"
    description: Which model to use. Not used if --model is provided.
    choices: ["large", "v2-medium", "small"]
    default: "v2-medium"
  - name: --model
    type: file
    description: Path to the scPRINT model.
    required: false
  - name: --batch_size
    type: integer
    description: The size of the batches to be used in the DataLoader.
    default: 32
  - name: --max_len
    type: integer
    description: The maximum length of the gene sequence.
    default: 4000

resources:
  - type: python_script
    path: script.py
  - path: /src/utils/read_anndata_partial.py
  - path: /src/utils/exit_codes.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    setup:
      - type: python
        pip:
          - scprint==2.2.1
          - gseapy==1.1.2
      - type: docker
        run: lamin init --storage ./main --name main --schema bionty
      - type: docker
        run: lamin load anonymous/main
      - type: python
        script: import bionty as bt; bt.core.sync_all_sources_to_latest()
      - type: python
        script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()

runners:
  - type: executable
    docker_run_args: --gpus all
  - type: nextflow
    directives:
      label: [hightime, highmem, midcpu, gpu, highsharedmem]
```
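The config declares argument defaults (`model_name`, `batch_size`, `max_len`) plus per-variant overrides under `info.variants`. As a rough illustration of how a variant layers on top of those defaults — this is a simplified sketch, not Viash's actual merging logic, and `resolve_par` is a hypothetical helper:

```python
# Argument defaults as declared in the config above.
defaults = {"model_name": "v2-medium", "batch_size": 32, "max_len": 4000}

# Per-variant overrides from info.variants.
variants = {
    "scprint_large": {"model_name": "large"},
    "scprint_medium": {"model_name": "v2-medium"},
    "scprint_small": {"model_name": "small"},
}


def resolve_par(variant: str) -> dict:
    # Later dict wins: variant values override the declared defaults.
    return {**defaults, **variants[variant]}


print(resolve_par("scprint_small"))
```

This mirrors why only `model_name` needs to appear under each variant: every other parameter falls through to its declared default.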
New file (+128 lines): `script.py`, the component's entry script.

```python
import os
import sys

import anndata as ad
import scprint
import torch
from huggingface_hub import hf_hub_download
from scdataloader import Preprocessor
from scprint import scPrint
from scprint.tasks import Embedder

## VIASH START
par = {
    "input": "resources_test/task_dimensionality_reduction/cxg_mouse_pancreas_atlas/dataset.h5ad",
    "output": "reduced.h5ad",
    "model_name": "v2-medium",
    "model": None,
    "batch_size": 32,
    "max_len": 4000,
}
meta = {"name": "scprint"}  # `resources_dir` is injected by Viash at runtime
## VIASH END

sys.path.append(meta["resources_dir"])
from exit_codes import exit_non_applicable
from read_anndata_partial import read_anndata

print(f"====== scPRINT version {scprint.__version__} ======", flush=True)

# Set suggested PyTorch environment variable
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

print("\n>>> Reading input data...", flush=True)
input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
if input.uns["dataset_organism"] == "homo_sapiens":
    input.obs["organism_ontology_term_id"] = "NCBITaxon:9606"
elif input.uns["dataset_organism"] == "mus_musculus":
    input.obs["organism_ontology_term_id"] = "NCBITaxon:10090"
else:
    exit_non_applicable(
        f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
    )
adata = input.copy()

print("\n>>> Preprocessing data...", flush=True)
preprocessor = Preprocessor(
    min_valid_genes_id=min(0.9 * adata.n_vars, 10000),  # 90% of features, up to 10,000
    # Turn off cell filtering to return results for all cells
    filter_cell_by_counts=False,
    min_nnz_genes=False,
    do_postp=False,
    # Skip ontology checks
    skip_validate=True,
)
adata = preprocessor(adata)

model_checkpoint_file = par["model"]
if model_checkpoint_file is None:
    print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
    model_checkpoint_file = hf_hub_download(
        repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
    )

print("\n>>> Embedding data...", flush=True)
if torch.cuda.is_available():
    print("CUDA is available, using GPU", flush=True)
    precision = "16"
    dtype = torch.float16
    transformer = "flash"
else:
    print("CUDA is not available, using CPU", flush=True)
    precision = "32"
    dtype = torch.float32
    transformer = "normal"

print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)

# Inspect the checkpoint's hyperparameters to decide whether class label
# counts need to be passed to load_from_checkpoint
m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
if "label_counts" in m["hyper_parameters"]:
    model = scPrint.load_from_checkpoint(
        model_checkpoint_file,
        transformer=transformer,  # "flash" on GPU, "normal" on CPU (see above)
        precpt_gene_emb=None,
        classes=m["hyper_parameters"]["label_counts"],
    )
else:
    model = scPrint.load_from_checkpoint(
        model_checkpoint_file,
        transformer=transformer,  # "flash" on GPU, "normal" on CPU (see above)
        precpt_gene_emb=None,
    )
del m

n_cores = min(len(os.sched_getaffinity(0)), 24)
print(f"Using {n_cores} worker cores", flush=True)
embedder = Embedder(
    how="random expr",
    batch_size=par["batch_size"],
    max_len=par["max_len"],
    add_zero_genes=0,
    num_workers=n_cores,
    pred_embedding=[],  # an empty list means all embeddings are used
    doclass=False,
    doplot=False,
    keep_all_cls_pred=False,
    output_expression="none",
    precision=precision,
    dtype=dtype,
)
embedded, _ = embedder(model, adata, cache=False)

print("\n>>> Storing output...", flush=True)
output = ad.AnnData(
    obs=input.obs[[]],
    var=input.var[[]],
    obsm={
        "X_emb": embedded.obsm["scprint_emb"],
    },
    uns={
        "dataset_id": input.uns["dataset_id"],
        "normalization_id": input.uns["normalization_id"],
        "method_id": meta["name"],
    },
)
print(output)

print("\n>>> Writing output AnnData to file...", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Done!", flush=True)
```
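The organism gate at the top of `script.py` is the part most likely to trip up new datasets: scPRINT only supports human and mouse. The same check can be factored into a standalone function — `organism_term_id` is an illustrative name, not part of the component:

```python
# Mapping used by script.py to set obs["organism_ontology_term_id"].
ORGANISM_TO_TAXON = {
    "homo_sapiens": "NCBITaxon:9606",
    "mus_musculus": "NCBITaxon:10090",
}


def organism_term_id(dataset_organism: str) -> str:
    """Return the NCBI Taxonomy term for a supported organism, else raise."""
    try:
        return ORGANISM_TO_TAXON[dataset_organism]
    except KeyError:
        # script.py calls exit_non_applicable() here instead of raising
        raise ValueError(
            f"scPRINT requires human or mouse data, not '{dataset_organism}'"
        )


print(organism_term_id("homo_sapiens"))
```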
New file (+7 lines): `exit_codes.py` (shipped as resource `/src/utils/exit_codes.py`).

```python
import sys


# When the method is not applicable to the input data,
# exit with code 99
def exit_non_applicable(msg):
    print(f"NON-APPLICABLE ERROR: {msg}", flush=True)
    sys.exit(99)
```
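The dedicated exit code lets a calling harness distinguish "not applicable" from a genuine failure. A minimal sketch of how a caller could check for it — the child process here is a throwaway `python -c` stand-in for the real component, not the benchmark's actual runner logic:

```python
import subprocess
import sys

# Stand-in child process that exits the way exit_non_applicable() does.
result = subprocess.run(
    [sys.executable, "-c", "import sys; sys.exit(99)"],
)

if result.returncode == 99:
    print("method not applicable, skipping")
elif result.returncode != 0:
    print("method failed")
```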
New file (+77 lines): `read_anndata_partial.py` (shipped as resource `/src/utils/read_anndata_partial.py`).

```python
import warnings
from pathlib import Path

import anndata as ad
import h5py
from anndata.experimental import read_elem, sparse_dataset
from scipy.sparse import csr_matrix


def read_anndata(
    file: str,
    backed: bool = False,
    **kwargs
) -> ad.AnnData:
    """
    Read an AnnData file, loading only the requested slots.

    :param file: path to AnnData file in h5ad format
    :param backed: read sparse matrices as backed sparse datasets
    :param kwargs: AnnData parameter to h5ad group mapping
    """
    assert Path(file).exists(), f'File not found: {file}'

    f = h5py.File(file, 'r')
    kwargs = {x: x for x in f} if not kwargs else kwargs
    if len(f.keys()) == 0:
        return ad.AnnData()
    # check if keys are available
    for name, slot in kwargs.items():
        if slot not in f:
            warnings.warn(
                f'Cannot find "{slot}" for AnnData parameter `{name}` from "{file}"'
            )
    adata = read_partial(f, backed=backed, **kwargs)
    if not backed:
        f.close()

    return adata


def read_partial(
    group: h5py.Group,
    backed: bool = False,
    force_sparse_types: str | list = None,
    **kwargs
) -> ad.AnnData:
    """
    Partially read h5py groups.

    :param group: h5py file or group
    :param backed: read sparse matrices as backed sparse datasets
    :param force_sparse_types: encoding types to convert to sparse via csr_matrix
    :param kwargs: dict of slot_name: slot; by default all available slots are read
    :return: AnnData object
    """
    if force_sparse_types is None:
        force_sparse_types = []
    elif isinstance(force_sparse_types, str):
        force_sparse_types = [force_sparse_types]
    slots = {}
    if backed:
        print('Read as backed sparse matrix...')

    for slot_name, slot in kwargs.items():
        print(f'Read slot "{slot}", store as "{slot_name}"...')
        if slot not in group:
            warnings.warn(f'Slot "{slot}" not found, skip...')
            slots[slot_name] = None
        else:
            elem = group[slot]
            iospec = ad._io.specs.get_spec(elem)
            if iospec.encoding_type in ("csr_matrix", "csc_matrix") and backed:
                slots[slot_name] = sparse_dataset(elem)
            elif iospec.encoding_type in force_sparse_types:
                slots[slot_name] = csr_matrix(read_elem(elem))
                if backed:
                    slots[slot_name] = sparse_dataset(slots[slot_name])
            else:
                slots[slot_name] = read_elem(elem)
    return ad.AnnData(**slots)
```
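One subtle line in `read_anndata` is the slot-mapping default, `kwargs = {x: x for x in f} if not kwargs else kwargs`: when no mapping is given, every top-level key in the h5ad file maps to itself. A pure-Python sketch of that behavior, with a plain list standing in for the h5py file keys:

```python
# Stand-in for the top-level keys of an h5py file.
available_keys = ["X", "obs", "var", "uns", "layers"]


def resolve_slots(available, **kwargs):
    # Mirrors: kwargs = {x: x for x in f} if not kwargs else kwargs
    # No mapping given -> read every slot under its own name;
    # otherwise use the caller's explicit mapping (e.g. X="layers/counts").
    return {x: x for x in available} if not kwargs else kwargs


print(resolve_slots(available_keys))
print(resolve_slots(available_keys, X="layers/counts", obs="obs"))
```

This is why `script.py` can write `read_anndata(par["input"], X="layers/counts", obs="obs", ...)`: the keyword names become AnnData attributes, and the values name the h5ad groups to read them from.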
Review comment:

> I think you need to remove this for the checks to pass