#!/usr/bin/env python3

+from typing import Any, Optional
+
+from torch import Tensor
+
+from ..distributions import MultivariateNormal
+from .exact_gp import ExactGP
+
from .gp import GP
from .pyro import _PyroMixin  # This will only contain functions if Pyro is installed

@@ -44,38 +51,38 @@ class ApproximateGP(GP, _PyroMixin):

    def __init__(self, variational_strategy):
        super().__init__()
+
        self.variational_strategy = variational_strategy

-    def forward(self, x):
+    def forward(self, x: Tensor):
        raise NotImplementedError

-    def pyro_guide(self, input, beta=1.0, name_prefix=""):
+    def pyro_guide(self, input: Tensor, beta: float = 1.0, name_prefix: str = ""):
        r"""
        (For Pyro integration only). The component of a `pyro.guide` that
        corresponds to drawing samples from the latent GP function.

-        :param torch.Tensor input: The inputs :math:`\mathbf X`.
-        :param float beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
+        :param input: The inputs :math:`\mathbf X`.
+        :param beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
            term by.
-        :param str name_prefix: (default="") A name prefix to prepend to pyro sample sites.
+        :param name_prefix: (default="") A name prefix to prepend to pyro sample sites.
        """
        return super().pyro_guide(input, beta=beta, name_prefix=name_prefix)

-    def pyro_model(self, input, beta=1.0, name_prefix=""):
+    def pyro_model(self, input: Tensor, beta: float = 1.0, name_prefix: str = "") -> Tensor:
        r"""
        (For Pyro integration only). The component of a `pyro.model` that
        corresponds to drawing samples from the latent GP function.

-        :param torch.Tensor input: The inputs :math:`\mathbf X`.
-        :param float beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
+        :param input: The inputs :math:`\mathbf X`.
+        :param beta: (default=1.) How much to scale the :math:`\text{KL} [ q(\mathbf f) \Vert p(\mathbf f) ]`
            term by.
-        :param str name_prefix: (default="") A name prefix to prepend to pyro sample sites.
+        :param name_prefix: (default="") A name prefix to prepend to pyro sample sites.
        :return: samples from :math:`q(\mathbf f)`
-        :rtype: torch.Tensor
        """
        return super().pyro_model(input, beta=beta, name_prefix=name_prefix)

-    def get_fantasy_model(self, inputs, targets, **kwargs):
+    def get_fantasy_model(self, inputs: Tensor, targets: Tensor, **kwargs: Any) -> ExactGP:
        r"""
        Returns a new GP model that incorporates the specified inputs and targets as new training data using
        online variational conditioning (OVC).
@@ -88,12 +95,11 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
        If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
        are the same for each target batch.

-        :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
+        :param inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations.
-        :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
+        :param targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
        :return: An `ExactGP` model with `n + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated.
-        :rtype: ~gpytorch.models.ExactGP

        Reference: "Conditioning Sparse Variational Gaussian Processes for Online Decision-Making,"
            Maddox, Stanton, Wilson, NeurIPS, '21
@@ -102,7 +108,7 @@ def get_fantasy_model(self, inputs, targets, **kwargs):
        """
        return self.variational_strategy.get_fantasy_model(inputs=inputs, targets=targets, **kwargs)

-    def __call__(self, inputs, prior=False, **kwargs):
-        if inputs.dim() == 1:
+    def __call__(self, inputs: Optional[Tensor], prior: bool = False, **kwargs) -> MultivariateNormal:
+        if inputs is not None and inputs.dim() == 1:
            inputs = inputs.unsqueeze(-1)
        return self.variational_strategy(inputs, prior=prior, **kwargs)
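
Usage sketch (not part of the diff above): a minimal example of how the newly annotated `__call__` and `get_fantasy_model` might be exercised. `MyApproximateGP`, `inducing_x`, `test_x`, `new_x`, and `new_y` are hypothetical names introduced for illustration; only the gpytorch classes themselves come from the library.

import torch

import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy


class MyApproximateGP(ApproximateGP):
    """Hypothetical subclass used only to illustrate the annotated API."""

    def __init__(self, inducing_x: torch.Tensor):
        variational_distribution = CholeskyVariationalDistribution(inducing_x.size(0))
        strategy = VariationalStrategy(
            self, inducing_x, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x: torch.Tensor) -> gpytorch.distributions.MultivariateNormal:
        # Prior p(f(x)) at the inputs; __call__ routes through the variational strategy.
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))


model = MyApproximateGP(inducing_x=torch.randn(16, 1))

# __call__ returns a MultivariateNormal; 1-D inputs are unsqueezed to `n x 1` as in the diff.
test_x = torch.randn(5)
pred = model(test_x)

# Online variational conditioning: get_fantasy_model returns an ExactGP that treats
# (new_x, new_y) as extra training data. Depending on the variational strategy and
# likelihood, additional kwargs may be required; this call is only a sketch.
new_x, new_y = torch.randn(3, 1), torch.randn(3)
fantasy_model = model.get_fantasy_model(inputs=new_x, targets=new_y)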