520 changes: 520 additions & 0 deletions examples/03_Multitask_Exact_GPs/GP_Factor_Analysis_Regression.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions examples/03_Multitask_Exact_GPs/index.rst
@@ -21,6 +21,7 @@ Multi-output (vector valued functions)
Multitask_GP_Regression.ipynb
Batch_Independent_Multioutput_GP.ipynb
ModelList_GP_Regression.ipynb
GP_Factor_Analysis_Regression.ipynb

Scalar function with multiple tasks
----------------------------------------
4 changes: 4 additions & 0 deletions gpytorch/kernels/__init__.py
@@ -6,6 +6,8 @@
from .cylindrical_kernel import CylindricalKernel
from .distributional_input_kernel import DistributionalInputKernel
from .gaussian_symmetrized_kl_kernel import GaussianSymmetrizedKLKernel
from .gpfa_component_kernel import GPFAComponentKernel
from .gpfa_kernel import GPFAKernel
from .grid_interpolation_kernel import GridInterpolationKernel
from .grid_kernel import GridKernel
from .index_kernel import IndexKernel
@@ -42,6 +44,7 @@
"GaussianSymmetrizedKLKernel",
"GridKernel",
"GridInterpolationKernel",
"GPFAKernel",
"IndexKernel",
"InducingPointKernel",
"LCMKernel",
@@ -57,6 +60,7 @@
"RBFKernel",
"RFFKernel",
"RBFKernelGrad",
"GPFAComponentKernel",
"RQKernel",
"ScaleKernel",
"SpectralDeltaKernel",
50 changes: 50 additions & 0 deletions gpytorch/kernels/gpfa_component_kernel.py
@@ -0,0 +1,50 @@
#!/usr/bin/env python3

import torch

from ..lazy import DiagLazyTensor, KroneckerProductLazyTensor, lazify
from .kernel import Kernel


class GPFAComponentKernel(Kernel):
r"""
    Component kernel for Gaussian Process Factor Analysis (GPFA), representing one of
    :math:`M` latent processes.

    Given a base covariance module for this latent, :math:`K_{XX}`, this kernel builds an
    :math:`M \times M` matrix :math:`K_{MM}` that is zero everywhere except at entry
    :math:`(kernel\_loc, kernel\_loc)`, and returns :math:`K = K_{XX} \otimes K_{MM}` as a
    :obj:`gpytorch.lazy.KroneckerProductLazyTensor`.

:param ~gpytorch.kernels.Kernel data_covar_module: Kernel to use as the latent kernel.
:param int num_latents: Number of latents (M)
:param int kernel_loc: Latent number that this kernel represents.
:param dict kwargs: Additional arguments to pass to the kernel.
"""

def __init__(self, data_covar_module, num_latents, kernel_loc, **kwargs):
"""
"""
super(GPFAComponentKernel, self).__init__(**kwargs)
task_diag = torch.zeros(num_latents)
task_diag[kernel_loc] = 1
self.task_covar = DiagLazyTensor(task_diag)
self.data_covar_module = data_covar_module
self.num_latents = num_latents

def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
if last_dim_is_batch:
raise RuntimeError("GPFAComponentKernel does not accept the last_dim_is_batch argument.")
covar_i = self.task_covar
if len(x1.shape[:-2]):
covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1)
covar_x = lazify(self.data_covar_module.forward(x1, x2, **params))
res = KroneckerProductLazyTensor(covar_x, covar_i)
return res.diag() if diag else res

def num_outputs_per_input(self, x1, x2):
"""
        Given `n` data points in `x1` and `m` data points in `x2`, this
kernel returns an `(n*num_latents) x (m*num_latents)` covariance matrix.
"""
return self.num_latents
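
A minimal usage sketch (not part of this PR) of how GPFAComponentKernel embeds a single latent kernel into the M-latent block structure; it assumes the class exactly as added above, and the example sizes are arbitrary:

# Minimal usage sketch (assumes the GPFAComponentKernel added in this PR).
import torch
import gpytorch
from gpytorch.kernels import GPFAComponentKernel

x = torch.linspace(0, 1, 5)
base = gpytorch.kernels.RBFKernel()
# Latent 0 of M = 3 latents: K_MM is zero except at entry (0, 0).
component = GPFAComponentKernel(base, num_latents=3, kernel_loc=0)
covar = component(x, x)
print(covar.shape)  # expected: torch.Size([15, 15]), i.e. (5 * 3, 5 * 3)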
75 changes: 75 additions & 0 deletions gpytorch/kernels/gpfa_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3
Reviewer comment (Member): To me, this kernel seems really similar to the LCMKernel that we already have in GPyTorch. Is there any way we can integrate the two?

I know that you want to be able to do posterior inference on the latent function (currently not totally possible with GPyTorch), and I'm not sure if this is possible with the current LCMKernel implementation. However, I'm afraid that having two similar kernels in the library will be difficult and confusing to maintain.

from copy import deepcopy

import torch

from ..lazy import DiagLazyTensor, KroneckerProductLazyTensor
from .gpfa_component_kernel import GPFAComponentKernel
from .kernel import AdditiveKernel, Kernel


class GPFAKernel(Kernel):
r"""
    Kernel supporting Gaussian Process Factor Analysis (GPFA), using
    :class:`gpytorch.kernels.GPFAComponentKernel` as the basic GPFA latent kernel.

    Given base covariance modules for the latents, :math:`k_i`, this kernel arranges them in a
    block-diagonal latent covariance :math:`K_{XX}` with :math:`M` blocks. It also defines a
    loading matrix :math:`C \in \mathbb{R}^{N \times M}` and returns
    :math:`(I_T \otimes C) K_{XX} (I_T \otimes C)^T` as a
    :obj:`gpytorch.lazy.LazyEvaluatedKernelTensor`.

    :param ~gpytorch.kernels.Kernel data_covar_modules: Kernel, or list of kernels (one per
        latent), to use for the latents; a single kernel is deep-copied for all :math:`M` latents.
:param int num_latents: Number of latents (M).
:param int num_obs: Number of observation dimensions (typically, the number of neurons, N).
    :param ~gpytorch.kernels.Kernel GPFA_component: (default: GPFAComponentKernel) Kernel class
        used to embed each latent kernel at its block location.
        GPFAComponentKernel is currently the only option; if non-reversible kernels are later
        added, there will then be other options here.
:param dict kwargs: Additional arguments to pass to the kernel.
"""

def __init__(
self, data_covar_modules, num_latents, num_obs, GPFA_component=GPFAComponentKernel, **kwargs,
):
super(GPFAKernel, self).__init__(**kwargs)
self.num_obs = num_obs
self.num_latents = num_latents

if not isinstance(data_covar_modules, list) or len(data_covar_modules) == 1:
if isinstance(data_covar_modules, list):
data_covar_modules = data_covar_modules[0]
data_covar_modules = [deepcopy(data_covar_modules) for i in range(num_latents)]

self.latent_covar_module = AdditiveKernel(
*[GPFA_component(data_covar_modules[i], num_latents, i) for i in range(num_latents)]
)
self.register_parameter(name="raw_C", parameter=torch.nn.Parameter(torch.randn(num_obs, num_latents)))

@property
def C(self):
return self.raw_C

@C.setter
def C(self, value):
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_C)

self.initialize(raw_C=value)

def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
if last_dim_is_batch:
raise RuntimeError("GPFAKernel does not yet accept the last_dim_is_batch argument.")
        # Identity over the time points, matched to the inputs' dtype and device.
        I_t1 = DiagLazyTensor(torch.ones(x1.shape[-2], dtype=x1.dtype, device=x1.device))
        I_t2 = DiagLazyTensor(torch.ones(x2.shape[-2], dtype=x2.dtype, device=x2.device))
kron_prod_1 = KroneckerProductLazyTensor(I_t1, self.C)
kron_prod_2 = KroneckerProductLazyTensor(I_t2, self.C)
covar = kron_prod_1 @ self.latent_covar_module(x1, x2, **params) @ kron_prod_2.t()
return covar.diag() if diag else covar

def num_outputs_per_input(self, x1, x2):
"""`
Given `n` data points `x1` and `m` datapoints `x2`, this
kernel returns an `(n*num_obs) x (m*num_obs)` covariance matrix.
"""
return self.num_obs
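
Again as a minimal sketch (not part of this PR), assuming the GPFAKernel defined above, with arbitrary example sizes: M = 2 latents and N = 4 observed dimensions evaluated at 5 time points should produce a (5 * 4) x (5 * 4) covariance:

# Minimal usage sketch (assumes the GPFAKernel added in this PR).
import torch
import gpytorch
from gpytorch.kernels import GPFAKernel

x = torch.linspace(0, 1, 5)
latent_kernels = [gpytorch.kernels.RBFKernel() for _ in range(2)]
gpfa = GPFAKernel(latent_kernels, num_latents=2, num_obs=4)
covar = gpfa(x, x)
print(covar.shape)  # expected: torch.Size([20, 20]), i.e. (5 * 4, 5 * 4)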
215 changes: 215 additions & 0 deletions test/examples/test_gpfa_regression.py
@@ -0,0 +1,215 @@
#!/usr/bin/env python3

import os
import random
import unittest

import numpy as np

import gpytorch
import torch
from gpytorch.kernels import GPFAKernel
from gpytorch.lazy import DiagLazyTensor, KroneckerProductDiagLazyTensor, KroneckerProductLazyTensor


def generate_GPFA_Data(seed=0,
n_timepoints=100,
n_latents=2,
num_obs=50,
length_scales=[.01, 10],
start_time=-5,
end_time=5,
zero_mean=True):
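    # Generative model, matching the sampling below:
    #   latents x_i ~ GP(0, k_i), with RBF kernels k_i of lengthscale length_scales[i]
    #   observations y_t = C x_t + d + eps_t, eps_t ~ N(0, R), with diagonal R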
torch.manual_seed(seed)
np.random.seed(seed)
timepoints = torch.linspace(start_time, end_time, n_timepoints)
tau = torch.tensor(length_scales)
C = torch.tensor(
np.random.normal(scale=1. / np.sqrt(n_latents),
size=(num_obs, n_latents))).float()
if zero_mean:
d = torch.zeros(size=(num_obs, 1)).float()
else:
d = torch.tensor(np.random.uniform(size=(num_obs, 1))).float()
R = torch.tensor(np.diag(np.random.uniform(size=(num_obs, ),
low=.1))).float()
kernels = [gpytorch.kernels.RBFKernel() for t in tau]
for t in range(len(tau)):
kernels[t].lengthscale = tau[t]

xs = torch.stack([
gpytorch.distributions.MultivariateNormal(torch.zeros(n_timepoints),
k(timepoints,
timepoints)).sample()
for k in kernels
])

ys = gpytorch.distributions.MultivariateNormal((C @ xs + d).T, R).sample()

xs = xs.T.contiguous()

return timepoints, tau, C, d, R, kernels, xs, ys


n_latents = 2
num_obs = 20
n_timepoints = 100
length_scales = [.1, .2]
start_time = 0
end_time = 1
train_x, tau, C, d, R, kernels, xs, train_y = generate_GPFA_Data(
seed=10,
n_timepoints=n_timepoints,
n_latents=n_latents,
num_obs=num_obs,
length_scales=length_scales,
start_time=start_time,
end_time=end_time,
zero_mean=True)

# For now, just test that GPFA more or less recovers the noiseless version of train_y
test_y = (C @ xs.T).T


class GPFAModel(gpytorch.models.ExactGP):
def __init__(self, train_x, train_y, likelihood, latent_covar_modules,
num_latents, num_obs):
super(GPFAModel, self).__init__(train_x, train_y, likelihood)

self.num_latents = num_latents
self.num_obs = num_obs

self.mean_module = gpytorch.means.MultitaskMean(
gpytorch.means.ZeroMean(), num_tasks=num_obs)
self.covar_module = GPFAKernel(latent_covar_modules, num_latents,
num_obs)

def forward(self, x):
return gpytorch.distributions.MultitaskMultivariateNormal(
self.mean_module(x), self.covar_module(x))

# Not currently used in this test. TODO: Test recovery of latents
def latent_posterior(self, x):
r'''
See equations 4 and 5 in `Non-reversible Gaussian processes for
identifying latent dynamical structure in neural data`_

.. _Non-reversible Gaussian processes for identifying latent dynamical structure in neural data:
https://papers.nips.cc/paper/2020/file/6d79e030371e47e6231337805a7a2685-Paper.pdf
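
        Concretely, matching the code below: the latent posterior mean is
        :math:`K_{XX} (I_T \otimes C)^T K_{yy}^{-1} (y - \mu)` and the posterior covariance is
        :math:`K_{XX} - K_{XX} (I_T \otimes C)^T K_{yy}^{-1} (I_T \otimes C) K_{XX}`, where
        :math:`K_{yy} = (I_T \otimes C) K_{XX} (I_T \otimes C)^T + I_T \otimes R`.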
'''
I_t = DiagLazyTensor(torch.ones(len(x)))
combined_noise = (self.likelihood.task_noises if self.likelihood.has_task_noise
else torch.zeros(self.likelihood.num_tasks)) + (
self.likelihood.noise
if self.likelihood.has_global_noise else 0)
Kyy = self.covar_module(x) + KroneckerProductDiagLazyTensor(
I_t, DiagLazyTensor(combined_noise))

Kxx = self.covar_module.latent_covar_module(x)

C_Kron_I = KroneckerProductLazyTensor(I_t, self.covar_module.C)

mean_rhs = (train_y - self.mean_module(x)).view(
*(train_y.numel(),
)) # vertically stacks after doing the subtraction

latent_mean = Kxx @ C_Kron_I.t() @ Kyy.inv_matmul(mean_rhs)
latent_mean = latent_mean.view(*(len(x),
int(latent_mean.shape[0] / len(x))))

cov_rhs = C_Kron_I @ Kxx
latent_cov = Kxx - Kxx @ C_Kron_I.t() @ Kyy.inv_matmul(
cov_rhs.evaluate())
return gpytorch.distributions.MultitaskMultivariateNormal(
latent_mean, latent_cov)


class TestGPFARegression(unittest.TestCase):
def setUp(self):
if os.getenv("UNLOCK_SEED") is None or os.getenv(
"UNLOCK_SEED").lower() == "false":
self.rng_state = torch.get_rng_state()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
random.seed(0)

def tearDown(self):
if hasattr(self, "rng_state"):
torch.set_rng_state(self.rng_state)

def test_multitask_gp_mean_abs_error(self):
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
num_tasks=num_obs)
kernels = [gpytorch.kernels.RBFKernel() for t in range(n_latents)]
model = GPFAModel(train_x, train_y, likelihood, kernels, n_latents,
num_obs)
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(
model.parameters(),
lr=0.1) # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

n_iter = 50
for _ in range(n_iter):
# Zero prev backpropped gradients
optimizer.zero_grad()
# Make predictions from training data
output = model(train_x)
loss = -mll(output, train_y)
loss.backward()
optimizer.step()

# Test the model predictions on noiseless test_ys
with torch.no_grad(), gpytorch.settings.fast_pred_var():
model.eval()
likelihood.eval()
preds = likelihood(model(train_x))
pred_mean = preds.mean
            mean_abs_error = torch.mean(torch.abs(test_y - pred_mean), dim=0)
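            # Require each output dimension's mean absolute error to stay within
            # twice its observation-noise variance (the diagonal of R).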
self.assertFalse(torch.sum(mean_abs_error > (2 * torch.diagonal(R))))

def test_multitask_gp_mean_abs_error_one_kernel(self):
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
num_tasks=num_obs)
model = GPFAModel(train_x, train_y, likelihood, gpytorch.kernels.RBFKernel(), n_latents,
num_obs)
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(
model.parameters(),
lr=0.1) # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

n_iter = 50
for _ in range(n_iter):
# Zero prev backpropped gradients
optimizer.zero_grad()
# Make predictions from training data
output = model(train_x)
loss = -mll(output, train_y)
loss.backward()
optimizer.step()

# Test the model predictions on noiseless test_ys
with torch.no_grad(), gpytorch.settings.fast_pred_var():
model.eval()
likelihood.eval()
preds = likelihood(model(train_x))
pred_mean = preds.mean
            mean_abs_error = torch.mean(torch.abs(test_y - pred_mean), dim=0)
self.assertFalse(torch.sum(mean_abs_error > (2 * torch.diagonal(R))))


if __name__ == "__main__":
unittest.main()