Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions _viash.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ authors:
info:
github: jacorvar
orcid: 0000-0002-7373-5433
- name: Jeremie Kalfon
roles: [contributor]
info:
github: jkobject
orcid: 0000-0002-2818-9728

# Step 7: Remove all of the comments of the steps you completed
# Step 8: High five yourself!
Expand Down
94 changes: 94 additions & 0 deletions src/methods/scprint/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
__merge__: /src/api/comp_method.yaml

name: scprint
label: scPRINT
summary: scPRINT is a large transformer model built for the inference of gene networks
description: |
scPRINT is a large transformer model built for the inference of gene networks
(connections between genes explaining the cell's expression profile) from
scRNAseq data.

It uses novel encoding and decoding of the cell expression profile and new
pre-training methodologies to learn a cell model.

scPRINT can be used to perform the following analyses:

- expression denoising: increase the resolution of your scRNAseq data
- cell embedding: generate a low-dimensional representation of your dataset
- label prediction: predict the cell type, disease, sequencer, sex, and
ethnicity of your cells
- gene network inference: generate a gene network from any cell or cell
cluster in your scRNAseq dataset

references:
doi:
- 10.1101/2024.07.29.605556

links:
documentation: https://cantinilab.github.io/scPRINT/
repository: https://github.com/cantinilab/scPRINT

info:
preferred_normalization: counts
method_types: [embedding]
variants:
scprint_large:
model_name: "large"
scprint_medium:
model_name: "v2-medium"
scprint_small:
model_name: "small"
test_setup:
run:
model_name: small
batch_size: 16
max_len: 100

arguments:
- name: "--model_name"
type: "string"
description: Which model to use. Not used if --model is provided.
choices: ["large", "v2-medium", "small"]
default: "v2-medium"
- name: --model
type: file
description: Path to the scPRINT model.
required: false
- name: --batch_size
type: integer
description: The size of the batches to be used in the DataLoader.
default: 32
- name: --max_len
type: integer
description: The maximum length of the gene sequence.
default: 4000

resources:
- type: python_script
path: script.py
- path: /src/utils/read_anndata_partial.py
- path: /src/utils/exit_codes.py

engines:
- type: docker
image: openproblems/base_pytorch_nvidia:1.0.0
setup:
- type: python
pip:
- scprint==2.2.1
- gseapy==1.1.2
- type: docker
run: lamin init --storage ./main --name main --schema bionty
- type: docker
run: lamin load anonymous/main
- type: python
script: import bionty as bt; bt.core.sync_all_sources_to_latest()
- type: python
script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()

runners:
- type: executable
docker_run_args: --gpus all
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to remove this for the checks to pass

- type: nextflow
directives:
label: [hightime, highmem, midcpu, gpu, highsharedmem]
128 changes: 128 additions & 0 deletions src/methods/scprint/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import os
import sys

import anndata as ad
import scprint
import torch
from huggingface_hub import hf_hub_download
from scdataloader import Preprocessor
from scprint import scPrint
from scprint.tasks import Embedder

## VIASH START
# Fallback parameter values for running the script directly during
# development; Viash substitutes this block at build time.
par = {
    "input": "resources_test/task_dimensionality_reduction/cxg_mouse_pancreas_atlas/dataset.h5ad",
    "output": "reduced.h5ad",
    "model_name": "v2-medium",
    "model": None,
}
meta = {"name": "scprint"}
## VIASH END

# Make the helper modules bundled as Viash resources importable
sys.path.append(meta["resources_dir"])
from exit_codes import exit_non_applicable
from read_anndata_partial import read_anndata

print(f"====== scPRINT version {scprint.__version__} ======", flush=True)

# Set suggested PyTorch environment variable
# (reduces CUDA memory fragmentation when allocating large models)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

print("\n>>> Reading input data...", flush=True)
# Read only the slots used downstream; X comes from the raw counts layer.
input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
# Map the dataset organism to its NCBI Taxonomy term, which the scPRINT
# preprocessing expects in obs; exit with the non-applicable code (99) for
# any organism other than human or mouse.
if input.uns["dataset_organism"] == "homo_sapiens":
    input.obs["organism_ontology_term_id"] = "NCBITaxon:9606"
elif input.uns["dataset_organism"] == "mus_musculus":
    input.obs["organism_ontology_term_id"] = "NCBITaxon:10090"
else:
    exit_non_applicable(
        f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
    )
# Work on a copy so `input` keeps the original obs/var/uns for the output step
adata = input.copy()

print("\n>>> Preprocessing data...", flush=True)
preprocessor = Preprocessor(
    min_valid_genes_id=min(0.9 * adata.n_vars, 10000),  # 90% of features up to 10,000
    # Turn off cell filtering to return results for all cells
    filter_cell_by_counts=False,
    min_nnz_genes=False,
    do_postp=False,
    # Skip ontology checks
    skip_validate=True,
)
adata = preprocessor(adata)

# Use a user-supplied checkpoint if given, otherwise download the named
# pretrained model from the Hugging Face Hub.
model_checkpoint_file = par["model"]
if model_checkpoint_file is None:
    print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
    model_checkpoint_file = hf_hub_download(
        repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
    )

print("\n>>> Embedding data...", flush=True)
# Choose precision and attention implementation based on available hardware:
# half precision with flash attention on GPU, full precision on CPU.
if torch.cuda.is_available():
    print("CUDA is available, using GPU", flush=True)
    precision = "16"
    dtype = torch.float16
    transformer = "flash"
else:
    print("CUDA is not available, using CPU", flush=True)
    precision = "32"
    dtype = torch.float32
    transformer = "normal"

print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)

# Peek at the checkpoint's hyperparameters: checkpoints that carry
# "label_counts" are loaded with an explicit `classes` argument
# (presumably newer checkpoint formats — see scPRINT docs).
m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
if "label_counts" in m["hyper_parameters"]:
    model = scPrint.load_from_checkpoint(
        model_checkpoint_file,
        transformer=transformer,  # Don't use this for GPUs with flashattention
        precpt_gene_emb=None,
        classes=m["hyper_parameters"]["label_counts"],
    )
else:
    model = scPrint.load_from_checkpoint(
        model_checkpoint_file,
        transformer=transformer,  # Don't use this for GPUs with flashattention
        precpt_gene_emb=None,
    )
# Drop the raw checkpoint dict before embedding to reduce peak memory
del m

# Cap DataLoader workers at 24 even on larger nodes
n_cores = min(len(os.sched_getaffinity(0)), 24)
print(f"Using {n_cores} worker cores")
embedder = Embedder(
    how="random expr",
    batch_size=par["batch_size"],
    max_len=par["max_len"],
    add_zero_genes=0,
    num_workers=n_cores,
    pred_embedding=[],  # none means using all
    doclass=False,
    doplot=False,
    keep_all_cls_pred=False,
    output_expression="none",
    precision=precision,
    dtype=dtype,
)
embedded, _ = embedder(model, adata, cache=False)

print("\n>>> Storing output...", flush=True)
# Build a minimal output AnnData per the task API: obs/var index only,
# the embedding in obsm["X_emb"], and the required uns metadata.
output = ad.AnnData(
    obs=input.obs[[]],
    var=input.var[[]],
    obsm={
        "X_emb": embedded.obsm["scprint_emb"],
    },
    uns={
        "dataset_id": input.uns["dataset_id"],
        "normalization_id": input.uns["normalization_id"],
        "method_id": meta["name"],
    },
)
print(output)

print("\n>>> Writing output AnnData to file...", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Done!", flush=True)
7 changes: 7 additions & 0 deletions src/utils/exit_codes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import sys

def exit_non_applicable(msg):
    """Report that the method is not applicable to the input and exit.

    Prints the message prefixed with "NON-APPLICABLE ERROR:" and terminates
    the process with exit code 99, which the benchmark pipeline interprets
    as "method not applicable" rather than a generic failure.

    :param msg: human-readable explanation of why the method does not apply
    :raises SystemExit: always, with code 99
    """
    print(f"NON-APPLICABLE ERROR: {msg}", flush=True)
    sys.exit(99)
77 changes: 77 additions & 0 deletions src/utils/read_anndata_partial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import warnings
from pathlib import Path
import anndata as ad
import h5py
from scipy.sparse import csr_matrix
from anndata.experimental import read_elem, sparse_dataset


def read_anndata(
    file: str,
    backed: bool = False,
    **kwargs
) -> ad.AnnData:
    """
    Read anndata file
    :param file: path to anndata file in h5ad format
    :param backed: read sparse matrices as backed sparse datasets
        instead of loading them into memory
    :param kwargs: AnnData parameter to group mapping (e.g. obs='obs');
        defaults to every top-level group found in the file
    """
    assert Path(file).exists(), f'File not found: {file}'

    f = h5py.File(file, 'r')
    kwargs = {x: x for x in f} if not kwargs else kwargs
    if len(f.keys()) == 0:
        # Fix: close the handle before the early return — previously the
        # open h5py.File was leaked for empty files.
        f.close()
        return ad.AnnData()
    # check if keys are available
    for name, slot in kwargs.items():
        if slot not in f:
            warnings.warn(
                f'Cannot find "{slot}" for AnnData parameter `{name}` from "{file}"'
            )
    adata = read_partial(f, backed=backed, **kwargs)
    if not backed:
        # Backed mode must keep the file open so sparse datasets stay readable
        f.close()

    return adata


def read_partial(
    group: h5py.Group,
    backed: bool = False,
    force_sparse_types: "str | list | None" = None,
    **kwargs
) -> ad.AnnData:
    """
    Partially read h5py groups
    :params group: file group
    :params force_sparse_types: encoding type (or list of types) to convert
        to sparse via csr_matrix; the original annotation `[str, list]` was
        an invalid list-literal annotation and has been corrected
    :params backed: read sparse matrix as sparse_dataset
    :params **kwargs: dict of slot_name: slot, by default use all available slot for the h5py file
    :return: AnnData object
    """
    # Normalise force_sparse_types to a list (None -> [], str -> [str])
    if force_sparse_types is None:
        force_sparse_types = []
    elif isinstance(force_sparse_types, str):
        force_sparse_types = [force_sparse_types]
    slots = {}
    if backed:
        print('Read as backed sparse matrix...')

    for slot_name, slot in kwargs.items():
        print(f'Read slot "{slot}", store as "{slot_name}"...')
        if slot not in group:
            # Missing slots are tolerated: warn and store None so the
            # resulting AnnData simply lacks that component.
            warnings.warn(f'Slot "{slot}" not found, skip...')
            slots[slot_name] = None
        else:
            elem = group[slot]
            # Inspect the on-disk encoding to pick the right reader
            iospec = ad._io.specs.get_spec(elem)
            if iospec.encoding_type in ("csr_matrix", "csc_matrix") and backed:
                # Keep sparse matrices on disk as backed datasets
                slots[slot_name] = sparse_dataset(elem)
            elif iospec.encoding_type in force_sparse_types:
                slots[slot_name] = csr_matrix(read_elem(elem))
                if backed:
                    # NOTE(review): sparse_dataset normally wraps an on-disk
                    # group, not an in-memory csr_matrix — confirm this
                    # branch behaves as intended when backed=True
                    slots[slot_name] = sparse_dataset(slots[slot_name])
            else:
                slots[slot_name] = read_elem(elem)
    return ad.AnnData(**slots)

Loading