33from typing import Any , Optional
44
55import torch
6- from jaxtyping import Float
76from linear_operator import to_dense
87from linear_operator .operators import DiagLinearOperator , LinearOperator , TriangularLinearOperator
98from linear_operator .utils .cholesky import psd_safe_cholesky
@@ -77,8 +76,8 @@ class NNVariationalStrategy(UnwhitenedVariationalStrategy):
7776 def __init__ (
7877 self ,
7978 model : ApproximateGP ,
80- inducing_points : Float [ Tensor , " ... M D" ],
81- variational_distribution : Float [ _VariationalDistribution , " ... M" ],
79+ inducing_points : Tensor , # shape: ( ..., M, D)
80+ variational_distribution : _VariationalDistribution , # shape: ( ..., M)
8281 k : int ,
8382 training_batch_size : Optional [int ] = None ,
8483 jitter_val : Optional [float ] = 1e-3 ,
@@ -120,21 +119,26 @@ def __init__(
120119
121120 @property
122121 @cached (name = "prior_distribution_memo" )
123- def prior_distribution (self ) -> Float [ MultivariateNormal , " ... M" ]:
122+ def prior_distribution (self ) -> MultivariateNormal : # shape: ( ..., M)
124123 out = self .model .forward (self .inducing_points )
125124 res = MultivariateNormal (out .mean , out .lazy_covariance_matrix .add_jitter (self .jitter_val ))
126125 return res
127126
128127 def _cholesky_factor (
129- self , induc_induc_covar : Float [LinearOperator , "... M M" ]
130- ) -> Float [TriangularLinearOperator , "... M M" ]:
128+ self ,
129+ induc_induc_covar : LinearOperator , # shape: (..., M, M)
130+ ) -> TriangularLinearOperator : # shape: (..., M, M)
131131 # Uncached version
132132 L = psd_safe_cholesky (to_dense (induc_induc_covar ))
133133 return TriangularLinearOperator (L )
134134
135135 def __call__ (
136- self , x : Float [Tensor , "... N D" ], prior : bool = False , diag : bool = True , ** kwargs : Any
137- ) -> Float [MultivariateNormal , "... N" ]:
136+ self ,
137+ x : Tensor , # shape: (..., N, D)
138+ prior : bool = False ,
139+ diag : bool = True ,
140+ ** kwargs : Any ,
141+ ) -> MultivariateNormal : # shape: (..., N)
138142 # If we're in prior mode, then we're done!
139143 if prior :
140144 return self .model .forward (x , ** kwargs )
@@ -176,13 +180,13 @@ def __call__(
176180
177181 def forward (
178182 self ,
179- x : Float [ Tensor , " ... N D" ],
180- inducing_points : Float [ Tensor , " ... M D" ],
181- inducing_values : Float [ Tensor , " ... M" ],
182- variational_inducing_covar : Optional [Float [ LinearOperator , " ... M M" ]] = None ,
183+ x : Tensor , # shape: ( ..., N, D)
184+ inducing_points : Tensor , # shape: ( ..., M, D)
185+ inducing_values : Tensor , # shape: ( ..., M)
186+ variational_inducing_covar : Optional [LinearOperator ] = None , # shape: ( ..., M, M)
183187 diag : bool = True ,
184188 ** kwargs : Any ,
185- ) -> Float [ MultivariateNormal , " ... N" ]:
189+ ) -> MultivariateNormal : # shape: ( ..., N)
186190 # TODO: This method needs to return the full covariance in eval mode, not just the predictive variance.
187191 # TODO: Use `diag` to control when to compute the variance vs. covariance in train mode.
188192 if self .training :
@@ -281,8 +285,8 @@ def forward(
281285
282286 def get_fantasy_model (
283287 self ,
284- inputs : Float [ Tensor , " ... N D" ],
285- targets : Float [ Tensor , " ... N" ],
288+ inputs : Tensor , # shape: ( ..., N, D)
289+ targets : Tensor , # shape: ( ..., N)
286290 mean_module : Optional [Module ] = None ,
287291 covar_module : Optional [Module ] = None ,
288292 ** kwargs ,
@@ -312,7 +316,7 @@ def _get_training_indices(self) -> LongTensor:
312316 self ._set_training_iterator ()
313317 return self .current_training_indices
314318
315- def _firstk_kl_helper (self ) -> Float [ Tensor , " ..." ]:
319+ def _firstk_kl_helper (self ) -> Tensor : # shape: ( ...)
316320 # Compute the KL divergence for first k inducing points
317321 train_x_firstk = self .inducing_points [..., : self .k , :]
318322 full_output = self .model .forward (train_x_firstk )
@@ -330,7 +334,10 @@ def _firstk_kl_helper(self) -> Float[Tensor, "..."]:
330334 kl = torch .distributions .kl .kl_divergence (variational_distribution , prior_dist ) # model_batch_shape
331335 return kl
332336
333- def _stochastic_kl_helper (self , kl_indices : Float [Tensor , "n_batch" ]) -> Float [Tensor , "..." ]: # noqa: F821
337+ def _stochastic_kl_helper (
338+ self ,
339+ kl_indices : Tensor , # shape: (n_batch,)
340+ ) -> Tensor : # shape: (...)
334341 # Compute the KL divergence for a mini batch of the rest M-k inducing points
335342 # See paper appendix for kl breakdown
336343 kl_bs = len (kl_indices ) # training_batch_size
@@ -435,7 +442,7 @@ def _stochastic_kl_helper(self, kl_indices: Float[Tensor, "n_batch"]) -> Float[T
435442
436443 def _kl_divergence (
437444 self , kl_indices : Optional [LongTensor ] = None , batch_size : Optional [int ] = None
438- ) -> Float [ Tensor , " ..." ]:
445+ ) -> Tensor : # shape: ( ...)
439446 if self .compute_full_kl or (self ._total_training_batches == 1 ):
440447 if batch_size is None :
441448 batch_size = self .training_batch_size
@@ -455,7 +462,7 @@ def _kl_divergence(
455462 kl = self ._stochastic_kl_helper (kl_indices ) * self .M / len (kl_indices )
456463 return kl
457464
458- def kl_divergence (self ) -> Float [ Tensor , " ..." ]:
465+ def kl_divergence (self ) -> Tensor : # shape: ( ...)
459466 try :
460467 return pop_from_cache (self , "kl_divergence_memo" )
461468 except CachingError :
0 commit comments