|
| 1 | +import anndata as ad |
| 2 | +import txsim as tx |
| 3 | +import scvi |
| 4 | +import pandas as pd |
| 5 | +import scanpy as sc |
| 6 | +import scipy |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +## VIASH START |
| 10 | +# Note: this section is auto-generated by viash at runtime. To edit it, make changes |
| 11 | +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. |
| 12 | +par = { |
| 13 | + 'input_spatial_with_cell_types': 'resources_test/task_ist_preprocessing/mouse_brain_combined/spatial_with_cell_types.h5ad', |
| 14 | + 'celltype_key': 'cell_type', |
| 15 | + 'output': '../resolvi_spatial_corrected.h5ad', |
| 16 | + 'n_hidden': 32, |
| 17 | + 'encode_covariates': False, |
| 18 | + 'downsample_counts': True |
| 19 | +} |
| 20 | +meta = { |
| 21 | + 'name': 'resolvi_correction', |
| 22 | +} |
| 23 | +## VIASH END |
| 24 | + |
| 25 | +# NOTE/TODO: for grid search: |
| 26 | +# - n_hidden: 32 (default), 64, 128 |
| 27 | +# - encode_covariates: False(default)/True |
| 28 | +# - downsample_counts: True(default)/False |
| 29 | + |
| 30 | +# Optional parameter check: For this specific correction method the par['input_sc'] is required |
| 31 | + |
| 32 | +# Read input |
| 33 | +print('Reading input files', flush=True) |
| 34 | +adata_sp = ad.read_h5ad(par['input_spatial_with_cell_types']) |
| 35 | +adata_sp.layers["normalized_uncorrected"] = adata_sp.layers["normalized"] |
| 36 | + |
| 37 | +print("Filter cells with <5 counts") |
| 38 | +sc.pp.filter_cells(adata_sp, min_genes=5) |
| 39 | + |
| 40 | +spatial_array = np.stack([adata_sp.obs['centroid_x'].values, adata_sp.obs['centroid_y'].values], axis=1) |
| 41 | +adata_sp.obsm['X_spatial'] = spatial_array |
| 42 | + |
| 43 | +# Apply gene efficiency correction |
| 44 | +print('Running ResolVI', flush=True) |
| 45 | + |
| 46 | +scvi.external.RESOLVI.setup_anndata(adata_sp, labels_key=par['celltype_key'], layer="counts") |
| 47 | + |
| 48 | +supervised_resolvi = scvi.external.RESOLVI(adata_sp, semisupervised=True, |
| 49 | + n_hidden = par['n_hidden'], |
| 50 | + encode_covariates = par['encode_covariates'], |
| 51 | + downsample_counts = par['downsample_counts']) |
| 52 | +supervised_resolvi.train(max_epochs=50) |
| 53 | + |
| 54 | +samples_corr = supervised_resolvi.sample_posterior( |
| 55 | + model=supervised_resolvi.module.model_corrected, |
| 56 | + return_sites=['px_rate'], |
| 57 | + summary_fun={"post_sample_q50": np.median}, |
| 58 | + num_samples=20, return_samples=False, batch_size=4000) #batch_steps was not a parameter |
| 59 | +samples_corr = pd.DataFrame(samples_corr).T |
| 60 | + |
| 61 | +samples = supervised_resolvi.sample_posterior( |
| 62 | + model=supervised_resolvi.module.model_residuals, |
| 63 | + return_sites=[ |
| 64 | + 'mixture_proportions', 'mean_poisson', 'per_gene_background', |
| 65 | + 'diffusion_mixture_proportion', 'per_neighbor_diffusion', 'px_r_inv' |
| 66 | + ], |
| 67 | + num_samples=20, return_samples=False, batch_size=4000) |
| 68 | +samples = pd.DataFrame(samples).T |
| 69 | + |
| 70 | + |
| 71 | +adata_sp.obsm["X_resolVI"] = supervised_resolvi.get_latent_representation() |
| 72 | + |
| 73 | +# TODO these 2 lines threw errors because 'obs' was not generated in samples_corr |
| 74 | +# adata_sp.layers["generated_expression"] = scipy.sparse.csr_matrix(samples_corr.loc['post_sample_q25', 'obs']) |
| 75 | +# adata_sp.layers["generated_expression_mean"] = scipy.sparse.csr_matrix(samples_corr.loc['post_sample_means', 'obs']) |
| 76 | + |
| 77 | +adata_sp.layers["corrected_counts"] = adata_sp.layers['counts'].multiply((samples_corr.loc['post_sample_q50', 'px_rate'] / ( |
| 78 | + 1.0 + samples_corr.loc['post_sample_q50', 'px_rate'] + samples.loc['post_sample_means', 'mean_poisson']))).tocsr() |
| 79 | + |
| 80 | +# Write output |
| 81 | +print('Writing output', flush=True) |
| 82 | +adata_sp.write(par['output']) |
0 commit comments