55import torch
66from mudata import AnnData , MuData
77from pandas import DataFrame
8- from pyro .infer import TraceEnum_ELBO
8+ from pyro import poutine
9+ from pyro .infer import TraceEnum_ELBO , infer_discrete
910from scipy .sparse import issparse
1011from scipy .stats import chi2
1112from scvi ._types import AnnOrMuData
1213from scvi .data import AnnDataManager , fields
13- from scvi .dataloaders import DeviceBackedDataSplitter
14+ from scvi .dataloaders import AnnDataLoader , DeviceBackedDataSplitter
1415from scvi .model .base import (
1516 BaseModelClass ,
1617 PyroSampleMixin ,
1718 PyroSviTrainMixin ,
1819)
19- from pyro import poutine
20- from pyro .infer import infer_discrete
21- from scvi .dataloaders import AnnDataLoader
22-
2320from scvi .train import PyroTrainingPlan
2421from scvi .utils ._docstrings import devices_dsp
2522
@@ -33,7 +30,7 @@ class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass):
3330 def __init__ (
3431 self ,
3532 mdata : AnnOrMuData ,
36- # control_guides=None,
33+ control_guides = None ,
3734 ** model_kwargs ,
3835 ):
3936 super ().__init__ (mdata )
@@ -57,7 +54,6 @@ def __init__(
5754 self .data_and_attrs .update ({REGISTRY_KEYS .CAT_COVS_KEY : np .float32 })
5855 n_cats_per_cov = self .adata_manager .get_state_registry (REGISTRY_KEYS .CAT_COVS_KEY ).n_cats_per_key
5956
60-
6157 guide_by_element = None
6258 n_elements = None
6359 if REGISTRY_KEYS .GUIDE_BY_ELEMENT_KEY in self .adata_manager .data_registry :
@@ -68,8 +64,31 @@ def __init__(
6864 if REGISTRY_KEYS .GENE_BY_ELEMENT_KEY in self .adata_manager .data_registry :
6965 gene_by_element = self .read_matrix_from_registry (REGISTRY_KEYS .GENE_BY_ELEMENT_KEY )
7066
71- gene_mean = self .adata_manager .get_from_registry (REGISTRY_KEYS .GENE_SUMMARY_STATS )
72- gene_mean = torch .tensor (gene_mean , dtype = torch .float32 )[:, 0 ]
67+ epsilon = 1e-3
68+ X = self .adata_manager .get_from_registry (REGISTRY_KEYS .X_KEY )
69+ grna_counts = self .adata_manager .get_from_registry (REGISTRY_KEYS .PERTURBATION_KEY )
70+
71+ if control_guides is not None :
72+ if issparse (grna_counts ):
73+ control_guide_idx = grna_counts [:, control_guides ].X .sum (axis = 1 ).A1 > 0
74+ else :
75+ control_guide_idx = grna_counts [:, control_guides ].X .sum (axis = 1 ) > 0
76+ X = X [control_guide_idx , :]
77+
78+ if issparse (X ):
79+ sample_mean = X .mean (axis = 0 ).A1 + epsilon
80+ sample_mean_squared = sample_mean * sample_mean
81+ sample_var = (X .multiply (X )).mean (axis = 0 ).A1 - sample_mean_squared
82+ else :
83+ sample_mean = X .mean (axis = 0 ).squeeze () + epsilon
84+ sample_mean_squared = sample_mean ** 2
85+ sample_var = (X ** 2 ).mean (axis = 0 ).squeeze () - sample_mean_squared
86+
87+ theta_hat = torch .tensor (sample_mean_squared / (sample_var - sample_mean )).clamp (min = 1e-1 )
88+ init_values = {
89+ "log_gene_mean" : torch .tensor (sample_mean , dtype = torch .float32 ).log (),
90+ "log_gene_dispersion" : torch .tensor (theta_hat ).log (),
91+ }
7392 # if control_guides is not None and "n_factors" in model_kwargs and guide_by_element is not None:
7493 # # control_guides, _ = torch.max(guide_by_element[:, control_elements], dim=-1)
7594 # control_mask = self.read_matrix_from_registry(REGISTRY_KEYS.PERTURBATION_KEY)[:, control_guides].sum(dim=-1)
@@ -85,7 +104,7 @@ def __init__(
85104 n_genes = self .summary_stats .n_vars ,
86105 n_cont_covariates = n_extra_continuous_covs ,
87106 n_elements = n_elements ,
88- gene_means = gene_mean ,
107+ init_values = init_values ,
89108 guide_by_element = guide_by_element ,
90109 gene_by_element = gene_by_element ,
91110 # n_cats_per_cov=n_cats_per_cov,
@@ -116,59 +135,6 @@ def setup_anndata(
116135 ):
117136 raise NotImplementedError ("MuData input required, use setup_mudata." )
118137
119- # setup_method_args = cls._get_setup_method_args(**locals())
120- # anndata_fields = [
121- # fields.LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=True),
122- # fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
123- # fields.CategoricalObsField(REGISTRY_KEYS.PERTURBATION_KEY, perturbation_key),
124- # fields.NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),
125- # fields.NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariates_keys),
126- # ]
127- # # add library size if not present
128- # if library_size_key is None:
129- # library_size_key = "_library_size"
130- # library_size = adata.X.sum(axis=1)
131- # if not library_size.all():
132- # raise ValueError(
133- # "Cannot infer library size: cells with zero counts. Set library_size_key manually instead."
134- # )
135- # adata.obs[library_size_key] = library_size
136-
137- # # add size factor if not present
138- # if size_factor_key is None:
139- # size_factor_key = "_size_factor"
140- # library_size = adata.obs[library_size_key]
141- # if not library_size.all():
142- # raise ValueError(
143- # "Cannot infer size factors: cells with zero library size. Set size_factor_key manually instead."
144- # )
145- # log_cpm = np.log(library_size / 1e6)
146- # adata.obs[size_factor_key] = log_cpm - log_cpm.mean()
147-
148- # # add indices to enable pyro subsampling of local vars
149- # adata.obs = adata.obs.assign(_ind_x=lambda x: np.arange(len(x)))
150- # index_field = fields.MuDataNumericalObsField(
151- # REGISTRY_KEYS.INDICES_KEY,
152- # "_ind_x",
153- # )
154-
155- # # add info for method of moments estimation of gene params
156- # mean_counts = np.mean(adata.X, axis=0)
157- # if isinstance(mean_counts, np.matrix): # occurs when summing sparse array
158- # mean_counts = mean_counts.A1
159- # adata.var["_gene_mean"] = mean_counts
160- # # rna_adata.var["_gene_variance"] = np.var(rna_adata.X, axis=0).squeeze()
161- # gene_field = fields.MuDataNumericalVarField(
162- # REGISTRY_KEYS.GENE_SUMMARY_STATS,
163- # "_gene_mean",
164- # )
165-
166- # adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
167- # adata_manager.register_fields(adata, **kwargs)
168- # cls.register_manager(adata_manager)
169-
170- # raise NotImplementedError("MuData input required, use setup_mudata.")
171-
172138 @classmethod
173139 def setup_mudata (
174140 cls ,
@@ -256,20 +222,6 @@ def setup_mudata(
256222 mod_key = modalities .rna_layer ,
257223 )
258224
259- # add info for method of moments estimation of gene params
260- if gene_mean_key is None :
261- gene_mean_key = "_gene_mean"
262- rna_adata = mdata [modalities .rna_layer ]
263- mean_counts = np .mean (rna_adata .X , axis = 0 )
264- if isinstance (mean_counts , np .matrix ): # occurs when summing sparse array
265- mean_counts = mean_counts .A1
266- rna_adata .var ["_gene_mean" ] = mean_counts
267- # rna_adata.var["_gene_variance"] = np.var(rna_adata.X, axis=0).squeeze()
268- gene_field = fields .MuDataNumericalVarField (
269- REGISTRY_KEYS .GENE_SUMMARY_STATS ,
270- "_gene_mean" ,
271- mod_key = modalities .rna_layer ,
272- )
273225
274226 batch_field = fields .MuDataCategoricalObsField (
275227 REGISTRY_KEYS .BATCH_KEY ,
@@ -286,7 +238,6 @@ def setup_mudata(
286238 mudata_fields = [
287239 index_field ,
288240 batch_field ,
289- gene_field ,
290241 fields .MuDataLayerField (
291242 REGISTRY_KEYS .PERTURBATION_KEY ,
292243 perturbation_layer ,
@@ -346,15 +297,15 @@ def setup_mudata(
346297 @devices_dsp .dedent
347298 def train (
348299 self ,
349- max_epochs : int | None = None ,
300+ max_epochs : int = 1000 ,
350301 accelerator : str = "cpu" ,
351302 device : int | str = "auto" ,
352303 train_size : float = 1.0 ,
353304 validation_size : float | None = None ,
354305 shuffle_set_split : bool = False ,
355- batch_size : int = 128 ,
306+ batch_size : int = 1024 ,
356307 early_stopping : bool = False ,
357- lr : float | None = None ,
308+ lr : float | None = 0.005 ,
358309 training_plan : PyroTrainingPlan = PyroTrainingPlan ,
359310 plan_kwargs : dict | None = None ,
360311 data_splitter_kwargs : dict | None = None ,
0 commit comments