Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cellij/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from ._data import DataContainer, Importer
from ._factormodel import FactorModel
from ._group import Group
from ._gp import PseudotimeGP

# from ._pyro_guides import Guide, HorseshoeGuide
# from ._pyro_models import MOFA_Model
from ._pyro_models import Generative, HorseshoeGenerative
from .models import MOFA
from .models import MOFA, SimpleGP
from .synthetic import DataGenerator
89 changes: 52 additions & 37 deletions cellij/core/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def __init__(self, *args, **kwargs):
@property
def values(self):
return self._values

@property
def shape(self):
    """(n_observations, n_features) of the merged data matrix.

    Computed from the merged observation/feature name lists, which are
    populated by ``merge_data`` — presumably only meaningful after merging.
    """
    n_obs = len(self._merged_obs_names)
    n_features = len(self._merged_feature_names)
    return n_obs, n_features

@property
def feature_groups(self):
Expand Down Expand Up @@ -115,44 +119,55 @@ def add_data(

# NOTE(review): this span was captured from a rendered diff view. It appears to
# contain BOTH the pre-change and the post-change bodies of `merge_data`
# superimposed, and the leading indentation was lost in the capture. The code
# is preserved byte-for-byte below; only comments are added. Reconcile against
# the repository before running — TODO confirm which copy is current.
def merge_data(self, **kwargs):
"""Merges all feature_groups into a single tensor."""
# --- first copy of the merge logic (apparently the pre-change version) ---
# Collect each feature group as a DataFrame, keyed by group name.
feature_groups = {}
obs_names = []
for name in self._names:
feature_groups[name] = self._feature_groups[name].to_df()

# Outer-join on observations so groups with disjoint samples still align;
# missing cells become NaN.
merged_feature_group = pd.concat(list(feature_groups.values()), axis=1, join='outer')
merged_obs_names = merged_feature_group.index.to_list()
merged_feature_names = merged_feature_group.columns

# Optional imputation: with no "na_strategy" kwarg the raw (NaN-bearing)
# values are stored as-is.
na_strategy = kwargs.get("na_strategy", None)
if na_strategy is None:
self._values = merged_feature_group.values


# --- second copy of the merge logic (apparently the post-change version) ---
# Guard: nothing to merge.
if len(self._feature_groups) == 0:
raise ValueError("No data to merge.")

# Single group: reuse its own index/columns directly (no concat needed).
if len(self._feature_groups) == 1:
name = self._names[0]
self._merged_obs_names = self._feature_groups[name].to_df().index.to_list()
self._merged_feature_names = self._feature_groups[name].to_df().columns

else:
# NOTE(review): in the captured text this `else:` is followed by the old
# impute branch AND by the duplicated merge logic further down; the true
# nesting cannot be recovered from this paste.
self._values = cellij.impute_data(
data=merged_feature_group,
strategy=na_strategy,
kwargs=kwargs,
)

self._merged_obs_names = merged_obs_names
self._merged_feature_names = merged_feature_names

# Record, per group, which rows/columns of the merged matrix belong to it.
for name in self._names:
feature_group_obs_names = self._feature_groups[name].obs_names.to_list()
feature_group_feature_names = self._feature_groups[name].var_names.to_list()

self._obs_idx[name] = [
i
for i, val in enumerate(merged_obs_names)
if val in feature_group_obs_names
]

self._feature_idx[name] = [
i
for i, val in enumerate(merged_feature_names)
if val in feature_group_feature_names
]

# (duplicated block from the diff overlay — same logic as above)
feature_groups = {}
obs_names = []
for name in self._names:
feature_groups[name] = self._feature_groups[name].to_df()

merged_feature_group = pd.concat(list(feature_groups.values()), axis=1, join='outer')
merged_obs_names = merged_feature_group.index.to_list()
merged_feature_names = merged_feature_group.columns

na_strategy = kwargs.get("na_strategy", None)
if na_strategy is None:
self._values = merged_feature_group.values

else:
self._values = cellij.impute_data(
data=merged_feature_group,
strategy=na_strategy,
kwargs=kwargs,
)

self._merged_obs_names = merged_obs_names
self._merged_feature_names = merged_feature_names

for name in self._names:
feature_group_obs_names = self._feature_groups[name].obs_names.to_list()
feature_group_feature_names = self._feature_groups[name].var_names.to_list()

self._obs_idx[name] = [
i
for i, val in enumerate(merged_obs_names)
if val in feature_group_obs_names
]

self._feature_idx[name] = [
i
for i, val in enumerate(merged_feature_names)
if val in feature_group_feature_names
]

def to_df(self) -> pd.DataFrame:
"""Returns a 'pandas.DataFrame' representation of the contained data with feature and observation names."""
Expand Down
160 changes: 135 additions & 25 deletions cellij/core/_factormodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from typing import Optional, Union

import anndata
import gpytorch
import muon
import numpy as np
import pandas
import pandas as pd
import pyro
import torch
from pyro.infer import SVI
Expand All @@ -20,9 +21,10 @@
from cellij.core.utils_training import EarlyStopper

logger = logging.getLogger(__name__)
rng = np.random.default_rng()


class FactorModel(PyroModule):
class FactorModel(PyroModule, gpytorch.Module):
"""Base class for all estimators in cellij.

Attributes
Expand Down Expand Up @@ -96,6 +98,7 @@ def __init__(
self._is_trained = False
self._feature_groups = {}
self._obs_groups = {}
self._covariate = None

# Save kwargs for later
self._kwargs = kwargs
Expand Down Expand Up @@ -179,6 +182,102 @@ def obs_groups(self, *args):
"Use `add_obs_group()`, `set_obs_group` or `remove_obs_group()` to modify this property."
)

@property
def covariate(self):
    """The covariate attached to this model, or ``None`` when none was added."""
    return self._covariate

@covariate.setter
def covariate(self, *args):
    # Direct assignment is forbidden on purpose: attaching a covariate needs
    # the extra setup performed by `add_covariate` (validation, conversion,
    # inducing-point selection).
    raise AttributeError("Use `add_covariate()` to modify this property.")

def add_covariate(
self,
covariate: Union[np.ndarray, pd.DataFrame, pd.Series, torch.Tensor],
num_inducing_points: int = 100,
):
"""
Add a covariate to the model, replacing any existing covariate if necessary.

Parameters
----------
covariate : Any
The covariate to be added to the model. Should be a 1D or 2D matrix-like object.
If 2D, it must not have more than 2 columns.
num_inducing_points : int, optional
The number of inducing points to keep. Default is 100.

Raises
------
TypeError
If the 'covariate' is not a matrix-like object.
ValueError
If the 'covariate' is not a 1D or 2D matrix, or if it is a 2D matrix with more than 2 columns.

Attributes updated
------------------
_covariate : torch.Tensor
The covariate data stored as a PyTorch tensor.
_inducing_points : torch.Tensor
Unique values from the covariate data, stored as a PyTorch tensor.
"""
if self.covariate is not None:
logger.info(
"Currently, only one covariate is supported. Overwriting existing covariate."
)
self.covariate = None

if not isinstance(
covariate, (np.ndarray, pd.DataFrame, pd.Series, torch.Tensor)
):
raise TypeError(
f"Parameter 'covariate' must be a matrix-like object, got {type(covariate)}."
)

if isinstance(covariate, np.ndarray):
covariate = pd.DataFrame(covariate.tolist())
elif isinstance(covariate, torch.Tensor):
covariate = pd.DataFrame(covariate.numpy().tolist())
elif isinstance(covariate, pd.Series):
covariate = covariate.to_frame()

covariate_shape_len = len(covariate.shape)
try:
if covariate_shape_len == 1:
rows = len(covariate)
cols = 1
elif covariate_shape_len == 2:
rows, cols = covariate.shape
else:
raise ValueError(
f"Parameter 'covariate' must be a 1D or 2D matrix, got shape '{covariate.shape}'."
)
except AttributeError as e:
raise TypeError(
f"Paramter 'covariate' must be a matrix-like object, got {type(covariate)}."
) from e

if cols == 0:
raise ValueError(
f"Parameter 'covariate' must have at least one column, got {cols}."
)

if cols > 2:
raise ValueError(
f"Parameter 'covariate' must have 1 or 2 columns, got {cols}."
)

self._covariate = torch.Tensor(covariate.values)

unique_points = covariate.drop_duplicates().values
if len(unique_points) > num_inducing_points:
unique_points = unique_points[
rng.choice(
len(unique_points), size=num_inducing_points, replace=False
)
]
self._inducing_points = torch.Tensor(unique_points)


def _setup_guide(self, guide, kwargs):
if isinstance(guide, str):
# Implement some default guides
Expand Down Expand Up @@ -224,13 +323,13 @@ def _setup_device(self, device):

def add_data(
self,
data: Union[pandas.DataFrame, anndata.AnnData, muon.MuData],
data: Union[pd.DataFrame, anndata.AnnData, muon.MuData],
name: Optional[str] = None,
merge: bool = True,
**kwargs,
):
# TODO: Add a check that no name is "all"
valid_types = (pandas.DataFrame, anndata.AnnData, muon.MuData)

valid_types = (pd.DataFrame, anndata.AnnData, muon.MuData)
metadata = None

if not isinstance(data, valid_types):
Expand All @@ -243,13 +342,9 @@ def add_data(
"When adding data that is not a MuData object, a name must be provided."
)

if isinstance(data, pandas.DataFrame):
data = anndata.AnnData(
X=data.values,
obs=pandas.DataFrame(data.index),
var=pandas.DataFrame(data.columns),
dtype=self._dtype,
)
if isinstance(data, pd.DataFrame):
data = anndata.AnnData(data)
self._add_data(data=data, name=name)

elif isinstance(data, anndata.AnnData):
self._add_data(data=data, name=name)
Expand Down Expand Up @@ -419,8 +514,8 @@ def _get_from_param_storage(

if key not in list(pyro.get_param_store().keys()):
raise ValueError(
f"Parameter '{key}' not found in parameter storage. "
f"Available choices are: {', '.join(list(pyro.get_param_store().keys()))}"
f"Parameter '{key}' not found in parameter storage. ",
f"Available choices are: {', '.join(list(pyro.get_param_store().keys()))}",
)

data = pyro.get_param_store()[key]
Expand Down Expand Up @@ -545,18 +640,34 @@ def fit(
for k, feature_idx in self._data._feature_idx.items()
}

if self.covariate is not None:
self.gp = cellij.core._gp.PseudotimeGP(
inducing_points=self._inducing_points, n_factors=self.n_factors
)

self._kwargs["gp"] = self.gp
self._kwargs["covariate"] = self.covariate

# Initialize class objects with correct data-related parameters
if not self._is_trained:
self._model = self._model(
n_samples=self._data._values.shape[0],
n_factors=self.n_factors,
feature_dict=feature_dict,
likelihoods=None,
device=self.device,
**self._kwargs,
)
self._model = self._model(
n_samples=len(self._data._merged_obs_names),
n_factors=self.n_factors,
feature_dict=feature_dict,
likelihoods=None,
device=self.device,
**self._kwargs,
)

# for key, value in self._kwargs.items():
# if key in ["init_loc", "init_scale"]:
# self._guide_kwargs[key] = value

if self.covariate is not None:
self._guide_kwargs["gp"] = self.gp
self._guide_kwargs["covariate"] = self.covariate

self._guide = self._guide(self._model, **self._guide_kwargs)
self._guide = self._guide(self._model, **self._guide_kwargs)

if not isinstance(likelihoods, (str, dict)):
raise ValueError(
Expand All @@ -581,7 +692,7 @@ def fit(
# models/datasets
scaling_constant = 1.0
if scale_gradients:
scaling_constant = 1.0 / self._data.values.shape[1]
scaling_constant = 1.0 / self._data.shape[1]

optim = pyro.optim.Adam({"lr": learning_rate, "betas": (0.95, 0.999)})
if optimizer.lower() == "clipped":
Expand All @@ -593,7 +704,6 @@ def fit(
model=pyro.poutine.scale(self._model, scale=scaling_constant),
guide=pyro.poutine.scale(self._guide, scale=scaling_constant),
optim=optim,
# loss=pyro.infer.Trace_ELBO(),
loss=pyro.infer.Trace_ELBO(
retain_graph=True,
num_particles=num_particles,
Expand Down
37 changes: 37 additions & 0 deletions cellij/core/_gp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import gpytorch
import torch


class PseudotimeGP(gpytorch.models.ApproximateGP):
    """Variational GP prior over factor values along a pseudotime covariate.

    One independent GP per factor (batched over ``n_factors``), with a zero
    mean and a scaled RBF kernel over fixed inducing-point locations.
    """

    def __init__(
        self,
        inducing_points: torch.Tensor,
        n_factors: int,
        init_lengthscale=5.0,
    ) -> None:
        # Number of inducing points determines the size of the variational
        # distribution.
        n_inducing = len(inducing_points)

        # One Cholesky-parameterised variational distribution per factor.
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points=n_inducing,
            batch_shape=torch.Size([n_factors]),
        )

        # NOTE: gpytorch's documented construction pattern — the strategy is
        # given `self` before super().__init__() runs; keep this order.
        variational_strategy = gpytorch.variational.VariationalStrategy(
            model=self,
            inducing_points=inducing_points,
            variational_distribution=variational_distribution,
            learn_inducing_locations=False,  # inducing locations stay fixed
        )

        super().__init__(variational_strategy=variational_strategy)
        self.mean_module = gpytorch.means.ZeroMean(
            batch_shape=torch.Size([n_factors]),
        )
        self.kernel = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([n_factors]))
        self.covar_module = gpytorch.kernels.ScaleKernel(self.kernel)
        # `self.kernel` IS the base kernel of `covar_module`, so this sets the
        # RBF lengthscale to its initial value.
        self.covar_module.base_kernel.lengthscale = torch.tensor(init_lengthscale)

    def forward(self, x) -> gpytorch.distributions.MultivariateNormal:
        """Return the (batched) GP prior evaluated at inputs ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
Loading