Skip to content

Commit c6bb27a

Browse files
seohyonkimmumichae
andauthored
Adding new method: DRVI (#61)
* script for drvi * add drvi to depenencies * add nvida image * changes after feedback * working DRVI mehtod * remove comments * remove comments, preprocessing * Update src/methods/drvi/script.py Co-authored-by: Michaela Müller <[email protected]> * Update src/methods/drvi/config.vsh.yaml Co-authored-by: Michaela Müller <[email protected]> * add changelog entry --------- Co-authored-by: Michaela Müller <[email protected]>
1 parent f738810 commit c6bb27a

File tree

3 files changed

+127
-0
lines changed

3 files changed

+127
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## New functionality
44

55
* Added `metrics/kbet_pg` and `metrics/kbet_pg_label` components (PR #52).
6+
* Added `method/drvi` component (PR #61).
67

78
## Minor changes
89

src/methods/drvi/config.vsh.yaml

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
__merge__: ../../api/comp_method.yaml
2+
name: drvi
3+
label: DRVI
4+
summary: "DrVI is an unsupervised generative model capable of learning non-linear interpretable disentangled latent representations from single-cell count data."
5+
description: |
6+
Disentangled Representation Variational Inference (DRVI) is an unsupervised deep generative model designed for integrating single-cell RNA sequencing (scRNA-seq) data across different batches.
7+
It extends the variational autoencoder (VAE) framework by learning a latent representation that captures biological variation while disentangling and correcting for batch effects.
8+
DRVI conditions both the encoder and decoder on batch covariates, allowing it to explicitly model and mitigate batch-specific variations during training.
9+
By incorporating a KL-divergence regularization term, it balances data reconstruction with latent space structure, resulting in a unified embedding where similar cells cluster together regardless of batch.
10+
references:
11+
doi:
12+
- 10.1101/2024.11.06.622266
13+
links:
14+
documentation: https://drvi.readthedocs.io/latest/index.html
15+
repository: https://github.com/theislab/DRVI?tab=readme-ov-file
16+
info:
17+
preferred_normalization: counts
18+
arguments:
19+
- name: --n_hvg
20+
type: integer
21+
default: 2000
22+
description: Number of highly variable genes to use.
23+
- name: --n_epochs
24+
type: integer
25+
default: 100
26+
description: Number of epochs
27+
resources:
28+
- type: python_script
29+
path: script.py
30+
- path: /src/utils/read_anndata_partial.py
31+
engines:
32+
- type: docker
33+
image: openproblems/base_pytorch_nvidia:1.0.0
34+
setup:
35+
- type: python
36+
pypi:
37+
- drvi-py==0.1.7
38+
- torch==2.3.0
39+
- torchvision==0.18.0
40+
runners:
41+
- type: executable
42+
- type: nextflow
43+
directives:
44+
label: [midtime,midmem,lowcpu,gpu]

src/methods/drvi/script.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import anndata as ad
2+
import scanpy as sc
3+
import drvi
4+
from drvi.model import DRVI
5+
from drvi.utils.misc import hvg_batch
6+
import pandas as pd
7+
import numpy as np
8+
import warnings
9+
import sys
10+
import scipy.sparse
11+
12+
## VIASH START
13+
par = {
14+
'input': 'resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad',
15+
'output': 'output.h5ad',
16+
'n_hvg': 2000,
17+
'n_epochs': 400
18+
}
19+
meta = {
20+
'name': 'drvi'
21+
}
22+
## VIASH END
23+
24+
sys.path.append(meta["resources_dir"])
25+
from read_anndata_partial import read_anndata
26+
27+
print('Reading input files', flush=True)
28+
adata = read_anndata(
29+
par['input'],
30+
X='layers/counts',
31+
obs='obs',
32+
var='var',
33+
uns='uns'
34+
)
35+
36+
if par["n_hvg"]:
37+
print(f"Select top {par['n_hvg']} high variable genes", flush=True)
38+
idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]]
39+
adata = adata[:, idx].copy()
40+
41+
print('Train model with DRVI', flush=True)
42+
43+
DRVI.setup_anndata(
44+
adata,
45+
categorical_covariate_keys=["batch"],
46+
is_count_data=False,
47+
)
48+
49+
model = DRVI(
50+
adata,
51+
categorical_covariates=["batch"],
52+
n_latent=128,
53+
encoder_dims=[128, 128],
54+
decoder_dims=[128, 128],
55+
)
56+
model
57+
58+
model.train(
59+
max_epochs=par["n_epochs"],
60+
early_stopping=False,
61+
early_stopping_patience=20,
62+
plan_kwargs={
63+
"n_epochs_kl_warmup": par["n_epochs"],
64+
},
65+
)
66+
67+
print("Store outputs", flush=True)
68+
output = ad.AnnData(
69+
obs=adata.obs.copy(),
70+
var=adata.var.copy(),
71+
obsm={
72+
"X_emb": model.get_latent_representation(),
73+
},
74+
uns={
75+
"dataset_id": adata.uns.get("dataset_id", "unknown"),
76+
"normalization_id": adata.uns.get("normalization_id", "unknown"),
77+
"method_id": meta["name"],
78+
},
79+
)
80+
81+
print("Write output AnnData to file", flush=True)
82+
output.write_h5ad(par['output'], compression='gzip')

0 commit comments

Comments
 (0)