
import torch

- from ..distributions import MultitaskMultivariateNormal
- from ..lazy import KroneckerProductLazyTensor, MatmulLazyTensor
+ from .. import settings
+ from ..distributions import MultitaskMultivariateNormal, MultivariateNormal
+ from ..lazy import KroneckerProductLazyTensor, RootLazyTensor
from ..module import Module
+ from ..utils.broadcasting import _mul_broadcast_shape
+ from ..utils.interpolation import left_interp
from ._variational_strategy import _VariationalStrategy


+ def _select_lmc_coefficients(lmc_coefficients: torch.Tensor, indices: torch.LongTensor) -> torch.Tensor:
+     """
+     Given a list of indices for ... x N datapoints,
+     select the row from lmc_coefficients that corresponds to each datapoint.
+
+     lmc_coefficients: torch.Tensor ... x num_latents x ... x num_tasks
+     indices: torch.Tensor ... x N
+     """
+     batch_shape = _mul_broadcast_shape(lmc_coefficients.shape[:-1], indices.shape[:-1])
+
+     # We will use the left_interp helper to do the indexing
+     lmc_coefficients = lmc_coefficients.expand(*batch_shape, lmc_coefficients.shape[-1])[..., None]
+     indices = indices.expand(*batch_shape, indices.shape[-1])[..., None]
+     res = left_interp(
+         indices, torch.ones(indices.shape, dtype=torch.long, device=indices.device), lmc_coefficients,
+     ).squeeze(-1)
+     return res
+
+
class LMCVariationalStrategy(_VariationalStrategy):
    r"""
    LMCVariationalStrategy is an implementation of the "Linear Model of Coregionalization"
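Aside: the `_select_lmc_coefficients` helper added above picks, for each datapoint, the mixing coefficient of that datapoint's assigned task. A minimal dense sketch of the same selection (hypothetical toy sizes, not part of this commit) is a single `torch.gather` along the task dimension:

```python
import torch

# Hypothetical toy sizes, for illustration only
num_latents, num_tasks, N = 3, 4, 5
lmc_coefficients = torch.randn(num_latents, num_tasks)  # num_latents x num_tasks
task_indices = torch.randint(num_tasks, (N,))           # one task index per datapoint

# Pick each datapoint's coefficient along the task dimension -- the same
# selection the left_interp-based helper performs after broadcasting
selected = torch.gather(lmc_coefficients, -1, task_indices.expand(num_latents, N))
print(selected.shape)  # torch.Size([3, 5]) -> num_latents x N
```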
@@ -20,8 +42,11 @@ class LMCVariationalStrategy(_VariationalStrategy):

        f_{\text{task } i}(\mathbf x) = \sum_{q=1}^Q a_i^{(q)} g^{(q)}(\mathbf x)

-     LMCVariationalStrategy wraps an existing :obj:`~gpytorch.variational.VariationalStrategy`
-     to produce a :obj:`~gpytorch.variational.MultitaskMultivariateNormal` distribution.
+     LMCVariationalStrategy wraps an existing :obj:`~gpytorch.variational.VariationalStrategy`.
+     The output will either be a :obj:`~gpytorch.distributions.MultitaskMultivariateNormal` distribution
+     (if we wish to evaluate all tasks for each input) or a :obj:`~gpytorch.distributions.MultivariateNormal`
+     (if we wish to evaluate a single task for each input).
+
    The base variational strategy is assumed to operate on a multi-batch of GPs, where one
    of the batch dimensions corresponds to the latent function dimension.

@@ -35,13 +60,6 @@ class LMCVariationalStrategy(_VariationalStrategy):
        batch shape. This would correspond to each of the latent functions having different kernels
        or the same kernel, respectively.

-     :param ~gpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
-     :param int num_tasks: The total number of tasks (output functions)
-     :param int num_latents: The total number of latent functions in each group
-     :param latent_dim: (Default: -1) Which batch dimension corresponds to the latent function batch.
-         **Must be negative indexed**
-     :type latent_dim: `int` < 0
-
    Example:
        >>> class LMCMultitaskGP(gpytorch.models.ApproximateGP):
        >>>     '''
@@ -74,7 +92,13 @@ class LMCVariationalStrategy(_VariationalStrategy):
        >>>             batch_shape=torch.Size([3]),
        >>>         )
        >>>
-         >>> # Model output: n x 5
+
+     :param ~gpytorch.variational.VariationalStrategy base_variational_strategy: Base variational strategy
+     :param int num_tasks: The total number of tasks (output functions)
+     :param int num_latents: The total number of latent functions in each group
+     :param latent_dim: (Default: -1) Which batch dimension corresponds to the latent function batch.
+         **Must be negative indexed**
+     :type latent_dim: `int` < 0
    """

    def __init__(
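For reviewers: a quick sketch of the two call modes this commit introduces, reusing the `LMCMultitaskGP` model from the docstring example above (the 5-task / 3-latent sizes and the 1-d inputs are assumptions based on that example):

```python
import torch

model = LMCMultitaskGP()           # 3 latent GPs mixed into 5 tasks
x = torch.randn(10, 1)             # assumed 1-d inputs

# Mode 1: no task_indices -> every input is evaluated on every task
all_tasks = model(x)               # MultitaskMultivariateNormal
print(all_tasks.mean.shape)        # torch.Size([10, 5])

# Mode 2: task_indices -> each input is evaluated only on its assigned task
task_indices = torch.randint(5, (10,))
one_task = model(x, task_indices=task_indices)  # MultivariateNormal
print(one_task.mean.shape)         # torch.Size([10])
```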
@@ -120,28 +144,84 @@ def variational_params_initialized(self):
    def kl_divergence(self):
        return super().kl_divergence().sum(dim=self.latent_dim)

-     def __call__(self, x, prior=False, **kwargs):
-         function_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
-         lmc_coefficients = self.lmc_coefficients.expand(*function_dist.batch_shape, self.lmc_coefficients.size(-1))
-         num_batch = len(function_dist.batch_shape)
-         num_dim = num_batch + len(function_dist.event_shape)
-         latent_dim = num_batch + self.latent_dim if self.latent_dim is not None else None
-
-         # Mean
-         mean = function_dist.mean.permute(*range(0, latent_dim), *range(latent_dim + 1, num_dim), latent_dim)
-         mean = mean @ lmc_coefficients.permute(
-             *range(0, latent_dim), *range(latent_dim + 1, num_dim - 1), latent_dim, -1
-         )
-
-         # Covar
-         covar = function_dist.lazy_covariance_matrix
-         lmc_factor = MatmulLazyTensor(lmc_coefficients.unsqueeze(-1), lmc_coefficients.unsqueeze(-2))
-         covar = KroneckerProductLazyTensor(covar, lmc_factor)
-         covar = covar.sum(latent_dim)
-
-         # Add a bit of jitter to make the covar PD
-         covar = covar.add_jitter(1e-6)
-
-         # Done!
-         function_dist = MultitaskMultivariateNormal(mean, covar)
+     def __call__(self, x, task_indices=None, prior=False, **kwargs):
+         r"""
+         Computes the variational (or prior) distribution
+         :math:`q( \mathbf f \mid \mathbf X)` (or :math:`p( \mathbf f \mid \mathbf X)`).
+         There are two modes:
+
+         1.  Compute **all tasks** for all inputs.
+             If this is the case, the :attr:`task_indices` attribute should be None.
+             The return type will be a (... x N x num_tasks)
+             :class:`~gpytorch.distributions.MultitaskMultivariateNormal`.
+         2.  Compute **one task** per input.
+             If this is the case, the (... x N) :attr:`task_indices` tensor should contain
+             the indices of each input's assigned task.
+             The return type will be a (... x N)
+             :class:`~gpytorch.distributions.MultivariateNormal`.
+
+         :param x: Input locations to evaluate variational strategy
+         :type x: torch.Tensor (... x N x D)
+         :param task_indices: (Default: None) Task index associated with each input.
+             If this **is not** provided, then the returned distribution evaluates every input on every task
+             (returns :class:`~gpytorch.distributions.MultitaskMultivariateNormal`).
+             If this **is** provided, then the returned distribution evaluates each input only on its assigned task
+             (returns :class:`~gpytorch.distributions.MultivariateNormal`).
+         :type task_indices: torch.Tensor (... x N), optional
+         :param prior: (Default: False) If False, returns the variational distribution
+             :math:`q( \mathbf f \mid \mathbf X)`.
+             If True, returns the prior distribution
+             :math:`p( \mathbf f \mid \mathbf X)`.
+         :type prior: bool
+         :return: :math:`q( \mathbf f \mid \mathbf X)` (or the prior),
+             either for all tasks (if `task_indices == None`)
+             or for a specific task (if `task_indices != None`).
+         :rtype: ~gpytorch.distributions.MultitaskMultivariateNormal (... x N x num_tasks)
+             or ~gpytorch.distributions.MultivariateNormal (... x N)
+         """
+         latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
+         num_batch = len(latent_dist.batch_shape)
+         latent_dim = num_batch + self.latent_dim
+
+         if task_indices is None:
+             num_dim = num_batch + len(latent_dist.event_shape)
+
+             # Every data point will get an output for each task
+             # Therefore, we will set up the lmc_coefficients shape for a matmul
+             lmc_coefficients = self.lmc_coefficients.expand(*latent_dist.batch_shape, self.lmc_coefficients.size(-1))
+
+             # Mean: ... x N x num_tasks
+             latent_mean = latent_dist.mean.permute(*range(0, latent_dim), *range(latent_dim + 1, num_dim), latent_dim)
+             mean = latent_mean @ lmc_coefficients.permute(
+                 *range(0, latent_dim), *range(latent_dim + 1, num_dim - 1), latent_dim, -1
+             )
+
+             # Covar: ... x (N x num_tasks) x (N x num_tasks)
+             latent_covar = latent_dist.lazy_covariance_matrix
+             lmc_factor = RootLazyTensor(lmc_coefficients.unsqueeze(-1))
+             covar = KroneckerProductLazyTensor(latent_covar, lmc_factor).sum(latent_dim)
+             # Add a bit of jitter to make the covar PD
+             covar = covar.add_jitter(settings.cholesky_jitter.value(dtype=mean.dtype))
+
+             # Done!
+             function_dist = MultitaskMultivariateNormal(mean, covar)
+
+         else:
+             # Each data point will get a single output corresponding to a single task
+             # Therefore, we will select the appropriate lmc coefficients for each task
+             lmc_coefficients = _select_lmc_coefficients(self.lmc_coefficients, task_indices)
+
+             # Mean: ... x N
+             mean = (latent_dist.mean * lmc_coefficients).sum(latent_dim)
+
+             # Covar: ... x N x N
+             latent_covar = latent_dist.lazy_covariance_matrix
+             lmc_factor = RootLazyTensor(lmc_coefficients.unsqueeze(-1))
+             covar = (latent_covar * lmc_factor).sum(latent_dim)
+             # Add a bit of jitter to make the covar PD
+             covar = covar.add_jitter(settings.cholesky_jitter.value(dtype=mean.dtype))
+
+             # Done!
+             function_dist = MultivariateNormal(mean, covar)
+
        return function_dist
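Note on the math in the all-tasks branch: the multitask covariance is assembled as sum_q K_q ⊗ a_q a_qᵀ (the `RootLazyTensor` of each coefficient vector, Kronecker-multiplied with the latent covariance, then summed over the latent dimension). A dense sketch of the same computation, with hypothetical toy sizes, for anyone checking the shapes:

```python
import torch

# Hypothetical toy sizes
Q, N, T = 3, 4, 2                 # latents, datapoints, tasks
K = torch.randn(Q, N, N)
K = K @ K.transpose(-1, -2)       # Q PSD latent covariance matrices
A = torch.randn(Q, T)             # lmc_coefficients: Q x T

# covar = sum_q K_q kron (a_q a_q^T), mirroring the lazy computation above
covar = sum(torch.kron(K[q], torch.outer(A[q], A[q])) for q in range(Q))
print(covar.shape)                # torch.Size([8, 8]) -> (N * T) x (N * T)
```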