Skip to content

Commit 67f59e4

Browse files
authored
Merge pull request #2711 from Balandat/remove_jaxtyping
Remove jaxtyping from gpytorch
2 parents 07fa44d + f075143 commit 67f59e4

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

gpytorch/utils/sum_interaction_terms.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22

33
import torch
44

5-
from jaxtyping import Float
65
from linear_operator import LinearOperator, to_dense
76
from torch import Tensor
87

98

109
def sum_interaction_terms(
11-
covars: Float[Union[LinearOperator, Tensor], "... D N N"],
10+
covars: Union[LinearOperator, Tensor], # shape: (..., D, N, N)
1211
max_degree: Optional[int] = None,
1312
dim: int = -3,
14-
) -> Float[Tensor, "... N N"]:
13+
) -> Tensor: # shape: (..., N, N)
1514
r"""
1615
Given a batch of D x N x N covariance matrices :math:`\boldsymbol K_1, \ldots, \boldsymbol K_D`,
1716
compute the sum of each covariance matrix as well as the interaction terms up to degree `max_degree`

gpytorch/variational/nearest_neighbor_variational_strategy.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Any, Optional
44

55
import torch
6-
from jaxtyping import Float
76
from linear_operator import to_dense
87
from linear_operator.operators import DiagLinearOperator, LinearOperator, TriangularLinearOperator
98
from linear_operator.utils.cholesky import psd_safe_cholesky
@@ -77,8 +76,8 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy):
7776
def __init__(
7877
self,
7978
model: ApproximateGP,
80-
inducing_points: Float[Tensor, "... M D"],
81-
variational_distribution: Float[_VariationalDistribution, "... M"],
79+
inducing_points: Tensor, # shape: (..., M, D)
80+
variational_distribution: _VariationalDistribution, # shape: (..., M)
8281
k: int,
8382
training_batch_size: Optional[int] = None,
8483
jitter_val: Optional[float] = 1e-3,
@@ -120,21 +119,26 @@ def __init__(
120119

121120
@property
122121
@cached(name="prior_distribution_memo")
123-
def prior_distribution(self) -> Float[MultivariateNormal, "... M"]:
122+
def prior_distribution(self) -> MultivariateNormal: # shape: (..., M)
124123
out = self.model.forward(self.inducing_points)
125124
res = MultivariateNormal(out.mean, out.lazy_covariance_matrix.add_jitter(self.jitter_val))
126125
return res
127126

128127
def _cholesky_factor(
129-
self, induc_induc_covar: Float[LinearOperator, "... M M"]
130-
) -> Float[TriangularLinearOperator, "... M M"]:
128+
self,
129+
induc_induc_covar: LinearOperator, # shape: (..., M, M)
130+
) -> TriangularLinearOperator: # shape: (..., M, M)
131131
# Uncached version
132132
L = psd_safe_cholesky(to_dense(induc_induc_covar))
133133
return TriangularLinearOperator(L)
134134

135135
def __call__(
136-
self, x: Float[Tensor, "... N D"], prior: bool = False, diag: bool = True, **kwargs: Any
137-
) -> Float[MultivariateNormal, "... N"]:
136+
self,
137+
x: Tensor, # shape: (..., N, D)
138+
prior: bool = False,
139+
diag: bool = True,
140+
**kwargs: Any,
141+
) -> MultivariateNormal: # shape: (..., N)
138142
# If we're in prior mode, then we're done!
139143
if prior:
140144
return self.model.forward(x, **kwargs)
@@ -176,13 +180,13 @@ def __call__(
176180

177181
def forward(
178182
self,
179-
x: Float[Tensor, "... N D"],
180-
inducing_points: Float[Tensor, "... M D"],
181-
inducing_values: Float[Tensor, "... M"],
182-
variational_inducing_covar: Optional[Float[LinearOperator, "... M M"]] = None,
183+
x: Tensor, # shape: (..., N, D)
184+
inducing_points: Tensor, # shape: (..., M, D)
185+
inducing_values: Tensor, # shape: (..., M)
186+
variational_inducing_covar: Optional[LinearOperator] = None, # shape: (..., M, M)
183187
diag: bool = True,
184188
**kwargs: Any,
185-
) -> Float[MultivariateNormal, "... N"]:
189+
) -> MultivariateNormal: # shape: (..., N)
186190
# TODO: This method needs to return the full covariance in eval mode, not just the predictive variance.
187191
# TODO: Use `diag` to control when to compute the variance vs. covariance in train mode.
188192
if self.training:
@@ -281,8 +285,8 @@ def forward(
281285

282286
def get_fantasy_model(
283287
self,
284-
inputs: Float[Tensor, "... N D"],
285-
targets: Float[Tensor, "... N"],
288+
inputs: Tensor, # shape: (..., N, D)
289+
targets: Tensor, # shape: (..., N)
286290
mean_module: Optional[Module] = None,
287291
covar_module: Optional[Module] = None,
288292
**kwargs,
@@ -312,7 +316,7 @@ def _get_training_indices(self) -> LongTensor:
312316
self._set_training_iterator()
313317
return self.current_training_indices
314318

315-
def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
319+
def _firstk_kl_helper(self) -> Tensor: # shape: (...)
316320
# Compute the KL divergence for first k inducing points
317321
train_x_firstk = self.inducing_points[..., : self.k, :]
318322
full_output = self.model.forward(train_x_firstk)
@@ -330,7 +334,10 @@ def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
330334
kl = torch.distributions.kl.kl_divergence(variational_distribution, prior_dist) # model_batch_shape
331335
return kl
332336

333-
def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[Tensor, "..."]: # noqa: F821
337+
def _stochastic_kl_helper(
338+
self,
339+
kl_indices: Tensor, # shape: (n_batch,)
340+
) -> Tensor: # shape: (...)
334341
# Compute the KL divergence for a mini batch of the rest M-k inducing points
335342
# See paper appendix for kl breakdown
336343
kl_bs = len(kl_indices) # training_batch_size
@@ -435,7 +442,7 @@ def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[T
435442

436443
def _kl_divergence(
437444
self, kl_indices: Optional[LongTensor] = None, batch_size: Optional[int] = None
438-
) -> Float[Tensor, "..."]:
445+
) -> Tensor: # shape: (...)
439446
if self.compute_full_kl or (self._total_training_batches == 1):
440447
if batch_size is None:
441448
batch_size = self.training_batch_size
@@ -455,7 +462,7 @@ def _kl_divergence(
455462
kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
456463
return kl
457464

458-
def kl_divergence(self) -> Float[Tensor, "..."]:
465+
def kl_divergence(self) -> Tensor: # shape: (...)
459466
try:
460467
return pop_from_cache(self, "kl_divergence_memo")
461468
except CachingError:

0 commit comments

Comments (0)