Skip to content

Commit 374d455

Browse files
committed
Merge remote-tracking branch 'origin/main' into feature/no-ref/add-mlflow-models
2 parents 5e49b7c + 1c34b59 commit 374d455

File tree

15 files changed

+327
-30
lines changed

15 files changed

+327
-30
lines changed

CHANGELOG.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@
33
## New functionality
44

55
* Added `metrics/kbet_pg` and `metrics/kbet_pg_label` components (PR #52).
6+
* Added `methods/stacas` new method (PR #58).
7+
- Add non-supervised version of STACAS tool for integration of single-cell transcriptomics data. This functionality enables correction of batch effects while preserving biological variability without requiring prior cell type annotations.
68
* Added `methods/drvi` component (PR #61).
9+
* Added `ARI_batch` and `NMI_batch` to `metrics/clustering_overlap` (PR #68).
10+
11+
* Added `metrics/cilisi` new metric component (PR #57).
12+
- ciLISI measures batch mixing in a cell type-aware manner by computing iLISI within each cell type and normalizing
13+
the scores between 0 and 1. Unlike iLISI, ciLISI preserves sensitivity to biological variance and avoids favoring
14+
overcorrected datasets with removed cell type signals.
15+
We propose adding this metric to substitute iLISI.
716

817
## Minor changes
918

@@ -12,7 +21,8 @@
1221

1322
## Bug fixes
1423

15-
* Update scPRINT to use latest stable version (PR #xx)
24+
* Update scPRINT to use latest stable version (PR #70)
25+
* Fix kbet dependencies to numpy<2 and scipy<=1.13 (PR #78).
1626

1727
# task_batch_integration 2.0.0
1828

scripts/render_report.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
#!/bin/bash
3+
4+
set -e
5+
6+
common/scripts/render_results_report "$@"

src/methods/scgpt_finetuned/config.vsh.yaml

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,20 @@ engines:
5151
image: openproblems/base_pytorch_nvidia:1
5252
# TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
5353
setup:
54-
- type: python
55-
pypi:
56-
- gdown
57-
- scgpt # Install from PyPI to get dependencies
54+
#- type: python
55+
# pypi:
56+
# - gdown
57+
# - scgpt # Install from PyPI to get dependencies
58+
#- type: docker
59+
# # Force re-installing from GitHub to get bug fixes
60+
# run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
5861
- type: docker
59-
# Force re-installing from GitHub to get bug fixes
60-
run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
62+
run: |
63+
git clone https://github.com/bowang-lab/scGPT && \
64+
pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \
65+
pip install "flash-attn<1.0.5" --no-build-isolation && \
66+
pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \
67+
cd scGPT && pip install -e . --no-deps
6168
6269
runners:
6370
- type: executable

src/methods/scgpt_zeroshot/config.vsh.yaml

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,20 @@ engines:
5353
image: openproblems/base_pytorch_nvidia:1
5454
# TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
5555
setup:
56-
- type: python
57-
pypi:
58-
- gdown
59-
- scgpt # Install from PyPI to get dependencies
56+
#- type: python
57+
# pypi:
58+
# - gdown
59+
# - scgpt # Install from PyPI to get dependencies
60+
#- type: docker
61+
# # Force re-installing from GitHub to get bug fixes
62+
# run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
6063
- type: docker
61-
# Force re-installing from GitHub to get bug fixes
62-
run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
64+
run: |
65+
git clone https://github.com/bowang-lab/scGPT && \
66+
pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121 && \
67+
pip install "flash-attn<1.0.5" --no-build-isolation && \
68+
pip install ipykernel pandas scanpy numba "numpy<1.24" torchtext==0.17.0 scib "scvi-tools<1.0" datasets==2.14.5 transformers==4.33.2 wandb "cell-gears<0.0.3" torch_geometric pyarrow==15.0.0 gdown && \
69+
cd scGPT && pip install -e . --no-deps
6370
6471
runners:
6572
- type: executable

src/methods/scprint/config.vsh.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ engines:
7575
setup:
7676
- type: python
7777
pip:
78-
- scprint>=2.3.0
79-
- gseapy>=1.1.8
78+
- scprint==2.3.5
8079
- type: docker
8180
run: |
8281
lamin init --storage ./main --name main --schema bionty && \
@@ -87,6 +86,7 @@ engines:
8786
script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()
8887
runners:
8988
- type: executable
89+
# docker_run_args: --gpus all
9090
- type: nextflow
9191
directives:
9292
label: [hightime, highmem, midcpu, gpu, highsharedmem]

src/methods/scprint/script.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from huggingface_hub import hf_hub_download
88
from scdataloader import Preprocessor
9+
from scdataloader.utils import load_genes
910
from scprint import scPrint
1011
from scprint.tasks import Embedder
1112

@@ -63,34 +64,55 @@
6364
repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
6465
)
6566

67+
print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
68+
6669
if torch.cuda.is_available():
6770
print("CUDA is available, using GPU", flush=True)
68-
precision = "16"
69-
dtype = torch.float16
7071
transformer = "flash"
7172
else:
7273
print("CUDA is not available, using CPU", flush=True)
73-
precision = "32"
74-
dtype = torch.float32
7574
transformer = "normal"
7675

77-
print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
76+
try:
77+
m = torch.load(model_checkpoint_file)
78+
# if not use this instead since the model weights are by default mapped to GPU types
79+
except RuntimeError:
80+
m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
7881

79-
m = torch.load(model_checkpoint_file, map_location=torch.device("cpu"))
82+
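# Both checks below are for compatibility with different versions of the pretrained model: remove the `prenorm` hyperparameter if present (re-saving the checkpoint) and pass `label_counts` as `classes` when available, then load the model with the correct transformer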
83+
if "prenorm" in m["hyper_parameters"]:
84+
m["hyper_parameters"].pop("prenorm")
85+
torch.save(m, model_checkpoint_file)
8086
if "label_counts" in m["hyper_parameters"]:
87+
# You need to set precpt_gene_emb=None; otherwise the model will look for its precomputed gene embedding files, even though they were already converted into model weights and are not needed for a pretrained model
8188
model = scPrint.load_from_checkpoint(
8289
model_checkpoint_file,
83-
transformer=transformer, # Don't use this for GPUs with flashattention
8490
precpt_gene_emb=None,
8591
classes=m["hyper_parameters"]["label_counts"],
92+
transformer=transformer,
8693
)
8794
else:
8895
model = scPrint.load_from_checkpoint(
89-
model_checkpoint_file,
90-
transformer=transformer, # Don't use this for GPUs with flashattention
91-
precpt_gene_emb=None,
96+
model_checkpoint_file, precpt_gene_emb=None, transformer=transformer
9297
)
9398
del m
99+
# This might happen if the model was trained with a different set of genes than the one used in the ontology (e.g. newer ontologies). Having genes in the ontology that are not in the model is fine, but the opposite is not, so we need to remove the genes that are in the model but not in the ontology
100+
missing = set(model.genes) - set(load_genes(model.organisms).index)
101+
if len(missing) > 0:
102+
print(
103+
"Warning: some genes missmatch exist between model and ontology: solving...",
104+
)
105+
model._rm_genes(missing)
106+
107+
# again if not on GPU you need to convert the model to float32
108+
if not torch.cuda.is_available():
109+
model = model.to(torch.float32)
110+
111+
# you can perform your inference on float16 if you have a GPU, otherwise use float32
112+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
113+
114+
# the models are often loaded with some parts still displayed as "cuda" and some as "cpu", so we need to make sure that the model is fully on the right device
115+
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
94116

95117
print("\n>>> Embedding data...", flush=True)
96118
n_cores = min(len(os.sched_getaffinity(0)), 24)
@@ -107,7 +129,6 @@
107129
keep_all_cls_pred=False,
108130
output_expression="none",
109131
save_every=30_000,
110-
precision=precision,
111132
dtype=dtype,
112133
)
113134
embedded, _ = embedder(model, adata, cache=False)

src/methods/stacas/config.vsh.yaml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
__merge__: ../../api/comp_method.yaml
2+
name: stacas
3+
label: STACAS
4+
summary: Accurate semi-supervised integration of single-cell transcriptomics data
5+
description: |
6+
STACAS is a method for scRNA-seq integration,
7+
especially suited to accurately integrate datasets with large cell type imbalance
8+
(e.g. in terms of proportions of distinct cell populations).
9+
Prior cell type knowledge, given as cell type labels, can be provided to the algorithm to perform
10+
semi-supervised integration, leading to increased preservation of biological variability
11+
in the resulting integrated space.
12+
STACAS is robust to incomplete cell type labels and can be applied to large-scale integration tasks.
13+
references:
14+
doi: 10.1038/s41467-024-45240-z
15+
# Andreatta M, Hérault L, Gueguen P, Gfeller D, Berenstein AJ, Carmona SJ.
16+
# Semi-supervised integration of single-cell transcriptomics data.
17+
# Nature Communications. 2024;15(1):1-13. doi:10.1038/s41467-024-45240-z
18+
links:
19+
documentation: https://carmonalab.github.io/STACAS.demo/STACAS.demo.html
20+
repository: https://github.com/carmonalab/STACAS
21+
info:
22+
preferred_normalization: log_cp10k
23+
method_types: [embedding]
24+
resources:
25+
- type: r_script
26+
path: script.R
27+
engines:
28+
- type: docker
29+
image: openproblems/base_r:1
30+
setup:
31+
- type: r
32+
github: carmonalab/STACAS@2.3.0
33+
runners:
34+
- type: executable
35+
- type: nextflow
36+
directives:
37+
label: [midtime,midmem,midcpu]

src/methods/stacas/script.R

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
requireNamespace("anndata", quietly = TRUE)
2+
suppressPackageStartupMessages({
3+
library(STACAS)
4+
library(Matrix)
5+
library(SeuratObject)
6+
library(Seurat)
7+
})
8+
9+
## VIASH START
10+
par <- list(
11+
input = "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
12+
output = "output.h5ad"
13+
)
14+
meta <- list(
15+
name = "stacas"
16+
)
17+
## VIASH END
18+
19+
cat("Reading input file\n")
20+
adata <- anndata::read_h5ad(par[["input"]])
21+
22+
cat("Create Seurat object\n")
23+
# Transpose because Seurat expects genes in rows, cells in columns
24+
counts_r <- Matrix::t(adata$layers[["counts"]])
25+
normalized_r <- Matrix::t(adata$layers[["normalized"]])
26+
# Convert to a regular sparse matrix first and then to dgCMatrix
27+
counts_c <- as(as(counts_r, "CsparseMatrix"), "dgCMatrix")
28+
normalized_c <- as(as(normalized_r, "CsparseMatrix"), "dgCMatrix")
29+
30+
# Create Seurat object with raw counts, these are needed to compute Variable Genes
31+
seurat_obj <- Seurat::CreateSeuratObject(counts = counts_c,
32+
meta.data = adata$obs)
33+
# Manually assign pre-normalized values to the "data" slot
34+
seurat_obj@assays$RNA$data <- normalized_c
35+
36+
cat("Run STACAS\n")
37+
object_integrated <- seurat_obj |>
38+
Seurat::SplitObject(split.by = "batch") |>
39+
STACAS::Run.STACAS()
40+
41+
cat("Store outputs\n")
42+
output <- anndata::AnnData(
43+
uns = list(
44+
dataset_id = adata$uns[["dataset_id"]],
45+
normalization_id = adata$uns[["normalization_id"]],
46+
method_id = meta$name
47+
),
48+
obs = adata$obs,
49+
var = adata$var,
50+
obsm = list(
51+
X_emb = object_integrated@reductions$pca@cell.embeddings
52+
)
53+
)
54+
55+
cat("Write output AnnData to file\n")
56+
output$write_h5ad(par[["output"]], compression = "gzip")

src/metrics/bras/config.vsh.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__merge__: ../../api/comp_metric.yaml
22
name: bras
33
info:
4+
metric_type: embedding
45
metrics:
56
- name: bras
67
label: BRAS

0 commit comments

Comments
 (0)