diff --git a/gpytorch/utils/sum_interaction_terms.py b/gpytorch/utils/sum_interaction_terms.py
index be7462132..ed9b001a6 100644
--- a/gpytorch/utils/sum_interaction_terms.py
+++ b/gpytorch/utils/sum_interaction_terms.py
@@ -2,16 +2,15 @@
 
 
 import torch
-from jaxtyping import Float
 from linear_operator import LinearOperator, to_dense
 from torch import Tensor
 
 
 def sum_interaction_terms(
-    covars: Float[Union[LinearOperator, Tensor], "... D N N"],
+    covars: Union[LinearOperator, Tensor],  # shape: (..., D, N, N)
     max_degree: Optional[int] = None,
     dim: int = -3,
-) -> Float[Tensor, "... N N"]:
+) -> Tensor:  # shape: (..., N, N)
     r"""
     Given a batch of D x N x N covariance matrices :math:`\boldsymbol K_1, \ldots, \boldsymbol K_D`,
     compute the sum of each covariance matrix as well as the interaction terms up to degree `max_degree`
diff --git a/gpytorch/variational/nearest_neighbor_variational_strategy.py b/gpytorch/variational/nearest_neighbor_variational_strategy.py
index 7b95a2c3f..1346cb388 100644
--- a/gpytorch/variational/nearest_neighbor_variational_strategy.py
+++ b/gpytorch/variational/nearest_neighbor_variational_strategy.py
@@ -3,7 +3,6 @@
 from typing import Any, Optional
 
 import torch
-from jaxtyping import Float
 from linear_operator import to_dense
 from linear_operator.operators import DiagLinearOperator, LinearOperator, TriangularLinearOperator
 from linear_operator.utils.cholesky import psd_safe_cholesky
@@ -77,8 +76,8 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy):
     def __init__(
         self,
         model: ApproximateGP,
-        inducing_points: Float[Tensor, "... M D"],
-        variational_distribution: Float[_VariationalDistribution, "... M"],
+        inducing_points: Tensor,  # shape: (..., M, D)
+        variational_distribution: _VariationalDistribution,  # shape: (..., M)
         k: int,
         training_batch_size: Optional[int] = None,
         jitter_val: Optional[float] = 1e-3,
@@ -120,21 +119,26 @@ def __init__(
 
     @property
     @cached(name="prior_distribution_memo")
-    def prior_distribution(self) -> Float[MultivariateNormal, "... M"]:
+    def prior_distribution(self) -> MultivariateNormal:  # shape: (..., M)
         out = self.model.forward(self.inducing_points)
         res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val))
         return res
 
     def _cholesky_factor(
-        self, induc_induc_covar: Float[LinearOperator, "... M M"]
-    ) -> Float[TriangularLinearOperator, "... M M"]:
+        self,
+        induc_induc_covar: LinearOperator,  # shape: (..., M, M)
+    ) -> TriangularLinearOperator:  # shape: (..., M, M)
         # Uncached version
         L = psd_safe_cholesky(to_dense(induc_induc_covar))
         return TriangularLinearOperator(L)
 
     def __call__(
-        self, x: Float[Tensor, "... N D"], prior: bool = False, diag: bool = True, **kwargs: Any
-    ) -> Float[MultivariateNormal, "... N"]:
+        self,
+        x: Tensor,  # shape: (..., N, D)
+        prior: bool = False,
+        diag: bool = True,
+        **kwargs: Any,
+    ) -> MultivariateNormal:  # shape: (..., N)
         # If we're in prior mode, then we're done!
         if prior:
             return self.model.forward(x, **kwargs)
@@ -176,13 +180,13 @@ def __call__(
 
     def forward(
         self,
-        x: Float[Tensor, "... N D"],
-        inducing_points: Float[Tensor, "... M D"],
-        inducing_values: Float[Tensor, "... M"],
-        variational_inducing_covar: Optional[Float[LinearOperator, "... M M"]] = None,
+        x: Tensor,  # shape: (..., N, D)
+        inducing_points: Tensor,  # shape: (..., M, D)
+        inducing_values: Tensor,  # shape: (..., M)
+        variational_inducing_covar: Optional[LinearOperator] = None,  # shape: (..., M, M)
         diag: bool = True,
         **kwargs: Any,
-    ) -> Float[MultivariateNormal, "... N"]:
+    ) -> MultivariateNormal:  # shape: (..., N)
         # TODO: This method needs to return the full covariance in eval mode, not just the predictive variance.
         # TODO: Use `diag` to control when to compute the variance vs. covariance in train mode.
         if self.training:
@@ -281,8 +285,8 @@ def forward(
 
     def get_fantasy_model(
         self,
-        inputs: Float[Tensor, "... N D"],
-        targets: Float[Tensor, "... N"],
+        inputs: Tensor,  # shape: (..., N, D)
+        targets: Tensor,  # shape: (..., N)
         mean_module: Optional[Module] = None,
         covar_module: Optional[Module] = None,
         **kwargs,
@@ -312,7 +316,7 @@ def _get_training_indices(self) -> LongTensor:
             self._set_training_iterator()
         return self.current_training_indices
 
-    def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
+    def _firstk_kl_helper(self) -> Tensor:  # shape: (...)
         # Compute the KL divergence for first k inducing points
         train_x_firstk = self.inducing_points[..., : self.k, :]
         full_output = self.model.forward(train_x_firstk)
@@ -330,7 +334,10 @@ def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
         kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist)  # model_batch_shape
         return kl
 
-    def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[Tensor, "..."]:  # noqa: F821
+    def _stochastic_kl_helper(
+        self,
+        kl_indices: Tensor,  # shape: (n_batch,)
+    ) -> Tensor:  # shape: (...)
         # Compute the KL divergence for a mini batch of the rest M-k inducing points
         # See paper appendix for kl breakdown
         kl_bs = len(kl_indices)  # training_batch_size
@@ -435,7 +442,7 @@ def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[T
 
     def _kl_divergence(
         self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None
-    ) -> Float[Tensor, "..."]:
+    ) -> Tensor:  # shape: (...)
         if self.compute_full_kl or (self._total_training_batches == 1):
             if batch_size is None:
                 batch_size = self.training_batch_size
@@ -455,7 +462,7 @@ def _kl_divergence(
             kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
         return kl
 
-    def kl_divergence(self) -> Float[Tensor, "..."]:
+    def kl_divergence(self) -> Tensor:  # shape: (...)
         try:
             return pop_from_cache(self, "kl_divergence_memo")
         except CachingError: