Skip to content

Commit 56ed6b7

Browse files
authored
Merge pull request #24 from openproblems-bio/add_resolvi
Add resolvi
2 parents 5b83d8e + c95a4a9 commit 56ed6b7

File tree

2 files changed

+139
-0
lines changed

2 files changed

+139
-0
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
__merge__: /src/api/comp_method_expression_correction.yaml
2+
3+
name: resolvi_correction
4+
label: "resolVI Correction"
5+
summary: "Corrects the expression of genes using resolVI"
6+
description: >-
7+
Corrects the expression of genes based on the resolVI method, a part of scvi-tools.
8+
links:
9+
documentation: "https://docs.scvi-tools.org/en/latest/user_guide/models/resolvi.html"
10+
repository: "https://github.com/scverse/scvi-tools"
11+
references:
12+
doi: "10.1101/2025.01.20.634005"
13+
14+
arguments:
15+
- name: --celltype_key
16+
required: false
17+
direction: input
18+
type: string
19+
default: cell_type
20+
21+
- name: --n_hidden
22+
required: false
23+
direction: input
24+
type: integer
25+
default: 32
26+
27+
- name: --encode_covariates
28+
required: false
29+
direction: input
30+
type: boolean
31+
default: false
32+
33+
- name: --downsample_counts
34+
required: false
35+
direction: input
36+
type: boolean
37+
default: true
38+
39+
resources:
40+
- type: python_script
41+
path: script.py
42+
43+
engines:
44+
- type: docker
45+
image: openproblems/base_python:1.0.0
46+
__merge__:
47+
- /src/base/setup_txsim_partial.yaml
48+
setup:
49+
- type: python
50+
pypi: [scvi-tools]
51+
- type: native
52+
53+
runners:
54+
- type: executable
55+
- type: nextflow
56+
directives:
57+
label: [ midtime, highcpu, highmem ]
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import anndata as ad
2+
import txsim as tx
3+
import scvi
4+
import pandas as pd
5+
import scanpy as sc
6+
import scipy
7+
import numpy as np
8+
9+
## VIASH START
10+
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
11+
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
12+
par = {
13+
'input_spatial_with_cell_types': 'resources_test/task_ist_preprocessing/mouse_brain_combined/spatial_with_cell_types.h5ad',
14+
'celltype_key': 'cell_type',
15+
'output': '../resolvi_spatial_corrected.h5ad',
16+
'n_hidden': 32,
17+
'encode_covariates': False,
18+
'downsample_counts': True
19+
}
20+
meta = {
21+
'name': 'resolvi_correction',
22+
}
23+
## VIASH END
24+
25+
# NOTE/TODO: for grid search:
26+
# - n_hidden: 32 (default), 64, 128
27+
# - encode_covariates: False(default)/True
28+
# - downsample_counts: True(default)/False
29+
30+
# Optional parameter check: For this specific correction method the par['input_sc'] is required
31+
32+
# Read input
33+
print('Reading input files', flush=True)
34+
adata_sp = ad.read_h5ad(par['input_spatial_with_cell_types'])
35+
adata_sp.layers["normalized_uncorrected"] = adata_sp.layers["normalized"]
36+
37+
print("Filter cells with <5 counts")
38+
sc.pp.filter_cells(adata_sp, min_genes=5)
39+
40+
spatial_array = np.stack([adata_sp.obs['centroid_x'].values, adata_sp.obs['centroid_y'].values], axis=1)
41+
adata_sp.obsm['X_spatial'] = spatial_array
42+
43+
# Apply gene efficiency correction
44+
print('Running ResolVI', flush=True)
45+
46+
scvi.external.RESOLVI.setup_anndata(adata_sp, labels_key=par['celltype_key'], layer="counts")
47+
48+
supervised_resolvi = scvi.external.RESOLVI(adata_sp, semisupervised=True,
49+
n_hidden = par['n_hidden'],
50+
encode_covariates = par['encode_covariates'],
51+
downsample_counts = par['downsample_counts'])
52+
supervised_resolvi.train(max_epochs=50)
53+
54+
samples_corr = supervised_resolvi.sample_posterior(
55+
model=supervised_resolvi.module.model_corrected,
56+
return_sites=['px_rate'],
57+
summary_fun={"post_sample_q50": np.median},
58+
num_samples=20, return_samples=False, batch_size=4000) #batch_steps was not a parameter
59+
samples_corr = pd.DataFrame(samples_corr).T
60+
61+
samples = supervised_resolvi.sample_posterior(
62+
model=supervised_resolvi.module.model_residuals,
63+
return_sites=[
64+
'mixture_proportions', 'mean_poisson', 'per_gene_background',
65+
'diffusion_mixture_proportion', 'per_neighbor_diffusion', 'px_r_inv'
66+
],
67+
num_samples=20, return_samples=False, batch_size=4000)
68+
samples = pd.DataFrame(samples).T
69+
70+
71+
adata_sp.obsm["X_resolVI"] = supervised_resolvi.get_latent_representation()
72+
73+
# TODO these 2 lines threw errors because 'obs' was not generated in samples_corr
74+
# adata_sp.layers["generated_expression"] = scipy.sparse.csr_matrix(samples_corr.loc['post_sample_q25', 'obs'])
75+
# adata_sp.layers["generated_expression_mean"] = scipy.sparse.csr_matrix(samples_corr.loc['post_sample_means', 'obs'])
76+
77+
adata_sp.layers["corrected_counts"] = adata_sp.layers['counts'].multiply((samples_corr.loc['post_sample_q50', 'px_rate'] / (
78+
1.0 + samples_corr.loc['post_sample_q50', 'px_rate'] + samples.loc['post_sample_means', 'mean_poisson']))).tocsr()
79+
80+
# Write output
81+
print('Writing output', flush=True)
82+
adata_sp.write(par['output'])

0 commit comments

Comments
 (0)