5 changes: 2 additions & 3 deletions gpytorch/utils/sum_interaction_terms.py
@@ -2,16 +2,15 @@

import torch

- from jaxtyping import Float
from linear_operator import LinearOperator, to_dense
from torch import Tensor


def sum_interaction_terms(
- covars: Float[Union[LinearOperator, Tensor], "... D N N"],
+ covars: Union[LinearOperator, Tensor], # shape: (..., D, N, N)
max_degree: Optional[int] = None,
dim: int = -3,
- ) -> Float[Tensor, "... N N"]:
+ ) -> Tensor: # shape: (..., N, N)
r"""
Given a batch of D x N x N covariance matrices :math:`\boldsymbol K_1, \ldots, \boldsymbol K_D`,
compute the sum of each covariance matrix as well as the interaction terms up to degree `max_degree`
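For orientation, here is a minimal, hypothetical usage sketch of the `sum_interaction_terms` signature shown above. The import path and the exact form of the interaction terms (assumed here to be elementwise products of the per-dimension covariance matrices) are assumptions for illustration and are not taken from this PR.

# Illustrative sketch only (not from this PR); import path and interaction-term
# definition are assumptions.
import torch

from gpytorch.utils.sum_interaction_terms import sum_interaction_terms

D, N = 3, 5
base = torch.randn(D, N, N)
covars = base @ base.transpose(-1, -2)  # D positive semi-definite N x N matrices

# Sum of all K_d plus the pairwise interaction terms, returned as an (N, N) matrix
out = sum_interaction_terms(covars, max_degree=2, dim=-3)

# Brute-force version of the assumed degree-2 expansion:
# sum_d K_d + sum_{d1 < d2} K_{d1} * K_{d2} (elementwise product)
expected = covars.sum(dim=0)
for d1 in range(D):
    for d2 in range(d1 + 1, D):
        expected = expected + covars[d1] * covars[d2]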
45 changes: 26 additions & 19 deletions gpytorch/variational/nearest_neighbor_variational_strategy.py
@@ -3,7 +3,6 @@
from typing import Any, Optional

import torch
- from jaxtyping import Float
from linear_operator import to_dense
from linear_operator.operators import DiagLinearOperator, LinearOperator, TriangularLinearOperator
from linear_operator.utils.cholesky import psd_safe_cholesky
@@ -77,8 +76,8 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy):
def __init__(
self,
model: ApproximateGP,
- inducing_points: Float[Tensor, "... M D"],
- variational_distribution: Float[_VariationalDistribution, "... M"],
+ inducing_points: Tensor, # shape: (..., M, D)
+ variational_distribution: _VariationalDistribution, # shape: (..., M)
k: int,
training_batch_size: Optional[int] = None,
jitter_val: Optional[float] = 1e-3,
@@ -120,21 +119,26 @@ def __init__(

@property
@cached(name="prior_distribution_memo")
- def prior_distribution(self) -> Float[MultivariateNormal, "... M"]:
+ def prior_distribution(self) -> MultivariateNormal: # shape: (..., M)
out = self.model.forward(self.inducing_points)
res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val))
return res

def _cholesky_factor(
- self, induc_induc_covar: Float[LinearOperator, "... M M"]
- ) -> Float[TriangularLinearOperator, "... M M"]:
+ self,
+ induc_induc_covar: LinearOperator, # shape: (..., M, M)
+ ) -> TriangularLinearOperator: # shape: (..., M, M)
# Uncached version
L = psd_safe_cholesky(to_dense(induc_induc_covar))
return TriangularLinearOperator(L)

def __call__(
- self, x: Float[Tensor, "... N D"], prior: bool = False, diag: bool = True, **kwargs: Any
- ) -> Float[MultivariateNormal, "... N"]:
+ self,
+ x: Tensor, # shape: (..., N, D)
+ prior: bool = False,
+ diag: bool = True,
+ **kwargs: Any,
+ ) -> MultivariateNormal: # shape: (..., N)
# If we're in prior mode, then we're done!
if prior:
return self.model.forward(x, **kwargs)
@@ -176,13 +180,13 @@ def __call__(

def forward(
self,
- x: Float[Tensor, "... N D"],
- inducing_points: Float[Tensor, "... M D"],
- inducing_values: Float[Tensor, "... M"],
- variational_inducing_covar: Optional[Float[LinearOperator, "... M M"]] = None,
+ x: Tensor, # shape: (..., N, D)
+ inducing_points: Tensor, # shape: (..., M, D)
+ inducing_values: Tensor, # shape: (..., M)
+ variational_inducing_covar: Optional[LinearOperator] = None, # shape: (..., M, M)
diag: bool = True,
**kwargs: Any,
- ) -> Float[MultivariateNormal, "... N"]:
+ ) -> MultivariateNormal: # shape: (..., N)
# TODO: This method needs to return the full covariance in eval mode, not just the predictive variance.
# TODO: Use `diag` to control when to compute the variance vs. covariance in train mode.
if self.training:
@@ -281,8 +285,8 @@ def forward(

def get_fantasy_model(
self,
- inputs: Float[Tensor, "... N D"],
- targets: Float[Tensor, "... N"],
+ inputs: Tensor, # shape: (..., N, D)
+ targets: Tensor, # shape: (..., N)
mean_module: Optional[Module] = None,
covar_module: Optional[Module] = None,
**kwargs,
@@ -312,7 +316,7 @@ def _get_training_indices(self) -> LongTensor:
self._set_training_iterator()
return self.current_training_indices

- def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
+ def _firstk_kl_helper(self) -> Tensor: # shape: (...)
# Compute the KL divergence for first k inducing points
train_x_firstk = self.inducing_points[..., : self.k, :]
full_output = self.model.forward(train_x_firstk)
@@ -330,7 +334,7 @@ def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist) # model_batch_shape
return kl

- def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[Tensor, "..."]: # noqa: F821
+ def _stochastic_kl_helper(
+ self,
+ kl_indices: Tensor, # shape: (n_batch,)
+ ) -> Tensor: # shape: (...)
# Compute the KL divergence for a mini batch of the rest M-k inducing points
# See paper appendix for kl breakdown
kl_bs = len(kl_indices) # training_batch_size
@@ -435,7 +442,7 @@ def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[T

def _kl_divergence(
self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None
- ) -> Float[Tensor, "..."]:
+ ) -> Tensor: # shape: (...)
if self.compute_full_kl or (self._total_training_batches == 1):
if batch_size is None:
batch_size = self.training_batch_size
@@ -455,7 +462,7 @@ def _kl_divergence(
kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
return kl

- def kl_divergence(self) -> Float[Tensor, "..."]:
+ def kl_divergence(self) -> Tensor: # shape: (...)
try:
return pop_from_cache(self, "kl_divergence_memo")
except CachingError: