Skip to content

Commit f128816

Browse files
committed
add method of moments initialization for dispersion
1 parent 506c1be commit f128816

File tree

4 files changed

+101
-162
lines changed

4 files changed

+101
-162
lines changed

src/perturbo/models/_model.py

Lines changed: 33 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,18 @@
55
import torch
66
from mudata import AnnData, MuData
77
from pandas import DataFrame
8-
from pyro.infer import TraceEnum_ELBO
8+
from pyro import poutine
9+
from pyro.infer import TraceEnum_ELBO, infer_discrete
910
from scipy.sparse import issparse
1011
from scipy.stats import chi2
1112
from scvi._types import AnnOrMuData
1213
from scvi.data import AnnDataManager, fields
13-
from scvi.dataloaders import DeviceBackedDataSplitter
14+
from scvi.dataloaders import AnnDataLoader, DeviceBackedDataSplitter
1415
from scvi.model.base import (
1516
BaseModelClass,
1617
PyroSampleMixin,
1718
PyroSviTrainMixin,
1819
)
19-
from pyro import poutine
20-
from pyro.infer import infer_discrete
21-
from scvi.dataloaders import AnnDataLoader
22-
2320
from scvi.train import PyroTrainingPlan
2421
from scvi.utils._docstrings import devices_dsp
2522

@@ -33,7 +30,7 @@ class PERTURBO(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass):
3330
def __init__(
3431
self,
3532
mdata: AnnOrMuData,
36-
# control_guides=None,
33+
control_guides=None,
3734
**model_kwargs,
3835
):
3936
super().__init__(mdata)
@@ -57,7 +54,6 @@ def __init__(
5754
self.data_and_attrs.update({REGISTRY_KEYS.CAT_COVS_KEY: np.float32})
5855
n_cats_per_cov = self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key
5956

60-
6157
guide_by_element = None
6258
n_elements = None
6359
if REGISTRY_KEYS.GUIDE_BY_ELEMENT_KEY in self.adata_manager.data_registry:
@@ -68,8 +64,31 @@ def __init__(
6864
if REGISTRY_KEYS.GENE_BY_ELEMENT_KEY in self.adata_manager.data_registry:
6965
gene_by_element = self.read_matrix_from_registry(REGISTRY_KEYS.GENE_BY_ELEMENT_KEY)
7066

71-
gene_mean = self.adata_manager.get_from_registry(REGISTRY_KEYS.GENE_SUMMARY_STATS)
72-
gene_mean = torch.tensor(gene_mean, dtype=torch.float32)[:, 0]
67+
epsilon = 1e-3
68+
X = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY)
69+
grna_counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.PERTURBATION_KEY)
70+
71+
if control_guides is not None:
72+
if issparse(grna_counts):
73+
control_guide_idx = grna_counts[:, control_guides].X.sum(axis=1).A1 > 0
74+
else:
75+
control_guide_idx = grna_counts[:, control_guides].X.sum(axis=1) > 0
76+
X = X[control_guide_idx, :]
77+
78+
if issparse(X):
79+
sample_mean = X.mean(axis=0).A1 + epsilon
80+
sample_mean_squared = sample_mean * sample_mean
81+
sample_var = (X.multiply(X)).mean(axis=0).A1 - sample_mean_squared
82+
else:
83+
sample_mean = X.mean(axis=0).squeeze() + epsilon
84+
sample_mean_squared = sample_mean**2
85+
sample_var = (X**2).mean(axis=0).squeeze() - sample_mean_squared
86+
87+
theta_hat = torch.tensor(sample_mean_squared / (sample_var - sample_mean)).clamp(min=1e-1)
88+
init_values = {
89+
"log_gene_mean": torch.tensor(sample_mean, dtype=torch.float32).log(),
90+
"log_gene_dispersion": torch.tensor(theta_hat).log(),
91+
}
7392
# if control_guides is not None and "n_factors" in model_kwargs and guide_by_element is not None:
7493
# # control_guides, _ = torch.max(guide_by_element[:, control_elements], dim=-1)
7594
# control_mask = self.read_matrix_from_registry(REGISTRY_KEYS.PERTURBATION_KEY)[:, control_guides].sum(dim=-1)
@@ -85,7 +104,7 @@ def __init__(
85104
n_genes=self.summary_stats.n_vars,
86105
n_cont_covariates=n_extra_continuous_covs,
87106
n_elements=n_elements,
88-
gene_means=gene_mean,
107+
init_values=init_values,
89108
guide_by_element=guide_by_element,
90109
gene_by_element=gene_by_element,
91110
# n_cats_per_cov=n_cats_per_cov,
@@ -116,59 +135,6 @@ def setup_anndata(
116135
):
117136
raise NotImplementedError("MuData input required, use setup_mudata.")
118137

119-
# setup_method_args = cls._get_setup_method_args(**locals())
120-
# anndata_fields = [
121-
# fields.LayerField(REGISTRY_KEYS.X_KEY, None, is_count_data=True),
122-
# fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
123-
# fields.CategoricalObsField(REGISTRY_KEYS.PERTURBATION_KEY, perturbation_key),
124-
# fields.NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),
125-
# fields.NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariates_keys),
126-
# ]
127-
# # add library size if not present
128-
# if library_size_key is None:
129-
# library_size_key = "_library_size"
130-
# library_size = adata.X.sum(axis=1)
131-
# if not library_size.all():
132-
# raise ValueError(
133-
# "Cannot infer library size: cells with zero counts. Set library_size_key manually instead."
134-
# )
135-
# adata.obs[library_size_key] = library_size
136-
137-
# # add size factor if not present
138-
# if size_factor_key is None:
139-
# size_factor_key = "_size_factor"
140-
# library_size = adata.obs[library_size_key]
141-
# if not library_size.all():
142-
# raise ValueError(
143-
# "Cannot infer size factors: cells with zero library size. Set size_factor_key manually instead."
144-
# )
145-
# log_cpm = np.log(library_size / 1e6)
146-
# adata.obs[size_factor_key] = log_cpm - log_cpm.mean()
147-
148-
# # add indices to enable pyro subsampling of local vars
149-
# adata.obs = adata.obs.assign(_ind_x=lambda x: np.arange(len(x)))
150-
# index_field = fields.MuDataNumericalObsField(
151-
# REGISTRY_KEYS.INDICES_KEY,
152-
# "_ind_x",
153-
# )
154-
155-
# # add info for method of moments estimation of gene params
156-
# mean_counts = np.mean(adata.X, axis=0)
157-
# if isinstance(mean_counts, np.matrix): # occurs when summing sparse array
158-
# mean_counts = mean_counts.A1
159-
# adata.var["_gene_mean"] = mean_counts
160-
# # rna_adata.var["_gene_variance"] = np.var(rna_adata.X, axis=0).squeeze()
161-
# gene_field = fields.MuDataNumericalVarField(
162-
# REGISTRY_KEYS.GENE_SUMMARY_STATS,
163-
# "_gene_mean",
164-
# )
165-
166-
# adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
167-
# adata_manager.register_fields(adata, **kwargs)
168-
# cls.register_manager(adata_manager)
169-
170-
# raise NotImplementedError("MuData input required, use setup_mudata.")
171-
172138
@classmethod
173139
def setup_mudata(
174140
cls,
@@ -256,20 +222,6 @@ def setup_mudata(
256222
mod_key=modalities.rna_layer,
257223
)
258224

259-
# add info for method of moments estimation of gene params
260-
if gene_mean_key is None:
261-
gene_mean_key = "_gene_mean"
262-
rna_adata = mdata[modalities.rna_layer]
263-
mean_counts = np.mean(rna_adata.X, axis=0)
264-
if isinstance(mean_counts, np.matrix): # occurs when summing sparse array
265-
mean_counts = mean_counts.A1
266-
rna_adata.var["_gene_mean"] = mean_counts
267-
# rna_adata.var["_gene_variance"] = np.var(rna_adata.X, axis=0).squeeze()
268-
gene_field = fields.MuDataNumericalVarField(
269-
REGISTRY_KEYS.GENE_SUMMARY_STATS,
270-
"_gene_mean",
271-
mod_key=modalities.rna_layer,
272-
)
273225

274226
batch_field = fields.MuDataCategoricalObsField(
275227
REGISTRY_KEYS.BATCH_KEY,
@@ -286,7 +238,6 @@ def setup_mudata(
286238
mudata_fields = [
287239
index_field,
288240
batch_field,
289-
gene_field,
290241
fields.MuDataLayerField(
291242
REGISTRY_KEYS.PERTURBATION_KEY,
292243
perturbation_layer,
@@ -346,15 +297,15 @@ def setup_mudata(
346297
@devices_dsp.dedent
347298
def train(
348299
self,
349-
max_epochs: int | None = None,
300+
max_epochs: int = 1000,
350301
accelerator: str = "cpu",
351302
device: int | str = "auto",
352303
train_size: float = 1.0,
353304
validation_size: float | None = None,
354305
shuffle_set_split: bool = False,
355-
batch_size: int = 128,
306+
batch_size: int = 1024,
356307
early_stopping: bool = False,
357-
lr: float | None = None,
308+
lr: float | None = 0.005,
358309
training_plan: PyroTrainingPlan = PyroTrainingPlan,
359310
plan_kwargs: dict | None = None,
360311
data_splitter_kwargs: dict | None = None,

0 commit comments

Comments
 (0)