Prediction strategies help #2184
-
Hi @alexpeters1208 - we've been wanting to add Vecchia models to GPyTorch for a while. I have my own implementation that I've been meaning to merge in, but I'm also curious to see your implementation. The strategy I used was not to instantiate a new prediction strategy, but instead to create a batch of GPs (each of which makes a prediction on a single data point). We set the training data for each GP in the batch to be the nearest neighbors of the target data point. This approach makes predictions very parallelizable. I'll attach it below.
-
```python
import argparse
import math
import os
import time

import numpy as np
import torch
import gpytorch
import faiss

from util import sample_batch_indices


class VecchiaModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, k=16):
        # Placeholder mask
        mask = train_x.new_ones(*train_x.shape[:-1], 1)
        super().__init__((train_x, mask), train_y, likelihood)
        self.k = k
        self.res = faiss.StandardGpuResources()
        self.register_buffer("train_x", train_x)
        self.register_buffer("train_indices", torch.arange(len(train_x)))
        self.register_buffer("train_nn_indices", torch.zeros(len(train_x), k, dtype=torch.long))
        self.register_buffer("train_nn_mask", torch.zeros(len(train_x), k, dtype=torch.bool))
        self.likelihood = likelihood
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.MaternKernel(nu=2.5, ard_num_dims=train_x.size(-1))
        )
        self.nn_time = 0

    def compute_train_nn_indices(self):
        assert self.k > 0
        start = time.time()
        with torch.no_grad():
            x = (self.train_x.data.float() / self.covar_module.base_kernel.lengthscale.data.float()).cpu().numpy()
            # Create ordering based on the first PCA vector
            mat = faiss.PCAMatrix(x.shape[-1], 1)
            mat.train(x)
            assert mat.is_trained
            projection = torch.from_numpy(mat.apply_py(x)).squeeze(-1).cuda()
            self.train_indices = projection.argsort()
            # Construct masked nearest-neighbor sets based on the ordering
            self.cpu_index = faiss.IndexFlatL2(self.train_x.size(-1))
            self.gpu_index = faiss.index_cpu_to_gpu(self.res, 0, self.cpu_index)
            for i, index in enumerate(self.train_indices.tolist()):
                row = x[index][None, :]
                self.gpu_index.add(row)
                self.train_nn_indices[index].copy_(
                    torch.from_numpy(self.gpu_index.search(row, self.k + 1)[1][..., 0, 1:]).long().to(self.train_x.device)
                )
                self.train_nn_mask[index, :min(i, self.k)] = True
        self.nn_time += (time.time() - start)

    def compute_test_nn_indices(self, x):
        with torch.no_grad():
            train_x = (self.train_x.data.float() / self.covar_module.base_kernel.lengthscale.data.float()).cpu().numpy()
            self.cpu_index = faiss.IndexFlatL2(self.train_x.size(-1))
            self.gpu_index = faiss.index_cpu_to_gpu(self.res, 0, self.cpu_index)
            self.gpu_index.add(train_x)
            x_np = (x.data.float() / self.covar_module.base_kernel.lengthscale.data.float()).cpu().numpy()
            return torch.from_numpy(self.gpu_index.search(x_np, self.k)[1]).long().to(x.device)

    def forward(self, x, mask):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x).evaluate()
        # Replace masked-out entries with an identity matrix
        eye = torch.eye(covar_x.size(-1), dtype=covar_x.dtype, device=covar_x.device)
        covar_x = torch.where(mask & mask.transpose(-1, -2), covar_x, eye)
        return gpytorch.distributions.MultivariateNormal(mean_x, gpytorch.lazify(covar_x))


def main(train_x, train_y, test_x, test_y, args):
    N_train = train_x.size(0)
    N_test = test_x.size(0)
    print("N_train: {} N_test: {} D: {}".format(N_train, N_test, train_x.size(-1)))

    likelihood = gpytorch.likelihoods.GaussianLikelihood().cuda()
    model = VecchiaModel(train_x, train_y, likelihood, k=args.k).cuda()

    optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=args.lr, betas=(0.90, 0.999))
    milestones = [int(k * args.num_iter) for k in [0.25, 0.5, 0.75]]
    sched = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.2)

    num_iter = args.num_iter
    report_freq = args.report_freq
    model.eval()  # We want eval mode, because the MLL is composed of Gaussian conditionals

    times = [time.time()]
    losses = []
    all_indices = []
    model.compute_train_nn_indices()

    for iteration in range(num_iter):
        if (iteration % args.nn_freq == 0) and iteration > 0:
            model.compute_train_nn_indices()
        if len(all_indices) == 0:
            all_indices = sample_batch_indices(N_train, args.mini_batch_size)

        mini_batch_indices = all_indices.pop()
        nn_indices = model.train_nn_indices[mini_batch_indices]
        nn_mask = model.train_nn_mask[mini_batch_indices]
        x_batch = train_x[nn_indices]
        y_batch = train_y[nn_indices]

        # We want a batch GP that can handle different amounts of training data.
        # For batches with < k training data points, the forward pass of the model
        # replaces masked-out parts of the covariance with the identity, so the
        # masked points have no effect on the rest of the data.
        model.set_train_data((x_batch, nn_mask[..., None]), y_batch, strict=False)

        optimizer.zero_grad()
        # Compute the predictive distribution of each y to get the MLL factors
        with gpytorch.settings.detach_test_caches(False):
            pred_x = train_x[mini_batch_indices][..., None, :]
            pred_y = train_y[mini_batch_indices][..., None]
            pred_mask = torch.ones(*pred_y.shape, 1, dtype=torch.bool, device=train_x.device)
            output = likelihood(model(pred_x, pred_mask))
            output = torch.distributions.Normal(output.mean, output.stddev)
            log_probs = output.log_prob(pred_y)
        loss = -log_probs.squeeze(dim=-1).mean(dim=-1)
        loss.backward()
        optimizer.step()
        sched.step()

        losses.append(loss.item())
        times.append(time.time())
        if iteration >= report_freq and ((iteration + 1) % report_freq == 0 or iteration == (num_iter - 1)):
            dt = (times[-1] - times[-1 - report_freq]) / report_freq
            lengthscale = model.covar_module.base_kernel.lengthscale
            print('Iter %d/%d - Loss: %.3f %.3f lengthscale: %.3f %.3f %.3f sigma: %.3f os: %.3f [dt: %.3f]' % (
                iteration + 1, num_iter, losses[-1], np.mean(losses[-report_freq:]),
                lengthscale.mean().item(), lengthscale.min().item(), lengthscale.max().item(),
                model.likelihood.noise.sqrt().item(),
                model.covar_module.outputscale.sqrt().item(), dt))

    print("Total Training Time: %.4f" % (times[-1] - times[0]))
    print("Total NN Time: %.4f" % model.nn_time)
    model.set_train_data(train_x, train_y, strict=False)

    # NN posterior
    test_mse = 0.
    test_nll = 0.
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        for x_batch, y_batch in zip(test_x.split(512), test_y.split(512)):
            # Clear cache
            model.train()
            model.eval()
            # Compute the NN posterior
            nn_indices = model.compute_test_nn_indices(x_batch)
            train_mask = torch.ones(*nn_indices.shape, 1, dtype=torch.bool, device=train_x.device)
            model.set_train_data((train_x[nn_indices, :], train_mask), train_y[nn_indices], strict=False)
            pred_mask = torch.ones(*y_batch.shape, 1, 1, dtype=torch.bool, device=train_x.device)
            posterior = model.likelihood(model(x_batch.unsqueeze(-2), pred_mask)).to_data_independent_dist()
            # Compute stats
            test_nll -= posterior.log_prob(y_batch.unsqueeze(-1)).squeeze(-1).sum(dim=-1).item()
            test_mse += (y_batch.unsqueeze(-1) - posterior.loc).pow(2.0).squeeze(-1).sum(dim=-1).item()

    # Reset training data, clear cache
    model.train()
    model.eval()
    model.set_train_data(train_x, train_y, strict=False)

    # Aggregate stats
    test_rmse = math.sqrt(test_mse / len(test_y))
    test_nll = test_nll / len(test_y)
    s = "[Seed {}: {} {}] - NN posterior TEST LL: {:.3f} RMSE: {:.3f}"
    print(s.format(args.seed, "vecchia", args.dataset, -test_nll, test_rmse))
```
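The script imports `argparse` but the driver isn't shown above. A minimal entry point along these lines is assumed; the argument names match the fields read from `args` in `main()`, but the defaults and the placeholder data are illustrative only:

```python
# Illustrative driver only -- argument names mirror the fields used in main(),
# defaults and synthetic data are placeholders, not part of the original script.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Vecchia GP example")
    parser.add_argument("--k", type=int, default=16)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--num-iter", type=int, default=1000)
    parser.add_argument("--report-freq", type=int, default=100)
    parser.add_argument("--nn-freq", type=int, default=250)
    parser.add_argument("--mini-batch-size", type=int, default=256)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dataset", type=str, default="synthetic")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Placeholder data -- substitute a real train/test split here
    train_x = torch.randn(2000, 8).cuda()
    train_y = torch.randn(2000).cuda()
    test_x = torch.randn(500, 8).cuda()
    test_y = torch.randn(500).cuda()

    main(train_x, train_y, test_x, test_y, args)
```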
-
Hello all,
I've been working with GPyTorch for a while and have developed a modular implementation of non-variational nearest-neighbor models, which I'm calling Vecchia models. This includes the nearest-neighbor Gaussian process, the block independent Gaussian process, and the block nearest-neighbor Gaussian process. Currently, I can train these different kinds of models with no problems and very little modification to the way users define GPyTorch models.
However, prediction is proving more difficult. The prediction strategies implemented in GPyTorch, in both the exact and variational contexts, appear to make use of the full joint distribution over the testing and training points. Predictions for Vecchia models, by contrast, are understood and expressed explicitly in terms of the conditional distributions of testing points given training points. Reverse engineering these conditional distributions to derive the full joint distribution may be possible, but it would lose all of the computational gains of these kinds of approximations.
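For concreteness, the per-test-point predictive I have in mind is just the standard GP regression conditional restricted to the $k$ nearest training neighbors $N(*)$ of each test point $x_*$ (with $\sigma^2$ the observation-noise variance):

$$
\mu_* = m(x_*) + K_{*,N}\,\bigl(K_{N,N} + \sigma^2 I\bigr)^{-1}\bigl(y_N - m(x_N)\bigr),
\qquad
\sigma_*^2 = K_{*,*} - K_{*,N}\,\bigl(K_{N,N} + \sigma^2 I\bigr)^{-1} K_{N,*}.
$$

Only a $k \times k$ system over the neighbor set ever needs to be solved, which is where the computational savings come from.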
So, what is the recommended way to implement this kind of prediction? I could implement a new model class that uses this prediction strategy by default (i.e., a vecchia_gp alongside exact_gp and approximate_gp), but I worry that I would lose the modularity my current implementation has (the Vecchia approximation can, in principle, be used in concert with variational methods, though I haven't tested this yet). I could create a new prediction_strategy and insist that users subclass ApproximateGP, but I again worry that I would then be unable to combine Vecchia approximations with variational methods and deep GPs. Maybe I could write a class that wraps any existing prediction strategy and makes predictions based on the conditional distributions rather than the joint distribution, but that doesn't feel like a great solution either.
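As a naive, unbatched illustration of what I mean by "predict from conditionals rather than the joint" (this is not my actual implementation): assuming a trained ExactGP `model` with its `likelihood`, the full `train_x`/`train_y`, and a precomputed `nn_indices` tensor whose row `i` holds the indices of the $k$ nearest training points of test point `i`, the behavior I want is roughly:

```python
# Naive per-test-point sketch, for illustration only.
import torch
import gpytorch

def predict_from_conditionals(model, likelihood, train_x, train_y, test_x, nn_indices):
    means, variances = [], []
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        for i in range(test_x.size(0)):
            idx = nn_indices[i]
            # Re-enter eval mode to clear stale prediction caches, then condition
            # only on this test point's nearest training neighbors
            model.train()
            model.eval()
            model.set_train_data(train_x[idx], train_y[idx], strict=False)
            pred = likelihood(model(test_x[i:i + 1]))
            means.append(pred.mean)
            variances.append(pred.variance)
    return torch.cat(means), torch.cat(variances)
```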
Thanks in advance for your time. I hope to contribute these new features to GPyTorch when they're complete and tested, because I think they will add a valuable new class of models to the wide variety of existing capabilities of GPyTorch.