 
 import warnings
 from abc import abstractmethod
+from collections import defaultdict, OrderedDict
 from copy import deepcopy
-from typing import Callable, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 from linear_operator import to_dense, to_linear_operator
-from linear_operator.operators import LinearOperator, ZeroLinearOperator
+from linear_operator.operators import KernelLinearOperator, LinearOperator, ZeroLinearOperator
 from torch import Tensor
 from torch.nn import ModuleList
 
@@ -75,6 +76,45 @@ def _dist(self, x1, x2, x1_eq_x2=False, postprocess=False):
         return self._postprocess(res) if postprocess else res
 
 
+class _autograd_kernel_hack(object):
+    """
+    Helper class.
+
+    When using KernelLinearOperator, the `covar_func` cannot close over any Tensors that require gradients.
+    (Any Tensor that `covar_func` closes over will not backpropagate gradients.)
+    Unfortunately, for most kernels, `covar_func=self.forward`, which closes over all of the kernel's parameters.
+
+    This context manager temporarily replaces a kernel's (and its submodules') parameter dictionaries with an
+    external set of references to the same parameters.
+    The external set of references will be passed in by KernelLinearOperator.
+
+    This way, when calling self.forward, no parameter references are closed over, and so all parameters
+    will receive the appropriate gradients.
+    """
+
+    def __init__(self, kernel: Kernel, params: Iterable[torch.nn.Parameter], param_names: Iterable[str]):
+        self.temp_module_param_dicts = defaultdict(OrderedDict)
+        for name, param in zip(param_names, params):
+            split_name = name.split(".")
+            module = kernel
+            while len(split_name) > 1:
+                module_name, *remaining_names = split_name
+                module = getattr(module, module_name)
+                split_name = remaining_names
+            (base_param_name,) = split_name
+            self.temp_module_param_dicts[module][base_param_name] = param
+
+        self.orig_model_param_dicts = dict((module, module._parameters) for module in self.temp_module_param_dicts)
+
+    def __enter__(self):
+        for module, temp_param_dict in self.temp_module_param_dicts.items():
+            object.__setattr__(module, "_parameters", temp_param_dict)
+
+    def __exit__(self, type, value, traceback):
+        for module, orig_param_dict in self.orig_model_param_dicts.items():
+            object.__setattr__(module, "_parameters", orig_param_dict)
+
+
 class Kernel(Module):
     r"""
     Kernels in GPyTorch are implemented as a :class:`gpytorch.Module` that, when called on two :class:`torch.Tensor`
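For intuition, here is a minimal standalone sketch of the `_parameters` swap that `_autograd_kernel_hack` performs. The `torch.nn.Linear` module and the `external_weight` tensor are illustrative stand-ins for a kernel submodule and the external parameter references supplied by KernelLinearOperator; this snippet is not part of the diff.

import torch
from collections import OrderedDict

lin = torch.nn.Linear(2, 2)
external_weight = lin.weight * 1.0          # a reference that lives in an external autograd graph

swapped = OrderedDict(lin._parameters)      # copy so the bias entry stays available
swapped["weight"] = external_weight

orig = lin._parameters
object.__setattr__(lin, "_parameters", swapped)     # attribute lookups now resolve to the swap
try:
    out = lin(torch.randn(3, 2)).sum()
finally:
    object.__setattr__(lin, "_parameters", orig)    # always restore the real Parameters

out.backward()
assert lin.weight.grad is not None          # gradients flowed back through external_weight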
@@ -206,6 +246,37 @@ def __init__(
         # TODO: Remove this on next official PyTorch release.
         self.__pdist_supports_batch = True
 
+    @property
+    def _lazily_evaluate(self) -> bool:
+        r"""
+        Determines whether or not the kernel is lazily evaluated.
+
+        If False, kernel(x1, x2) produces a Tensor/LinearOperator where the covariance function has been evaluated
+        over x1 and x2.
+
+        If True, kernel(x1, x2) produces a KernelLinearOperator that delays evaluation of the kernel function.
+        The kernel function will only be evaluated when either
+        - a mathematical operation is performed on the kernel matrix (e.g. solves, logdets, etc.), or
+        - an indexing operation is performed on the kernel matrix to select specific covariance entries.
+
+        In general, _lazily_evaluate should return True (this option is more efficient), unless lazy evaluation
+        offers no gains and there is specific structure that would be lost with lazy evaluation
+        (e.g. low-rank/Nystrom approximations).
+        """
+        return True
+
+    def _kernel_linear_operator_covar_func(
+        self, x1: Tensor, x2: Tensor, *params: torch.nn.Parameter, param_names: Iterable[str] = (), **kwargs: Any
+    ) -> Union[Tensor, LinearOperator]:
+        # This is the `covar_func` that is passed into KernelLinearOperator.
+        # This function calls self.forward, but does so in a way that no parameters are closed over
+        # (by using the _autograd_kernel_hack context manager).
+        if any(param.requires_grad for param in params):
+            with _autograd_kernel_hack(self, params, param_names):
+                return self.forward(x1, x2, **kwargs)
+        else:
+            return self.forward(x1, x2, **kwargs)
+
     def _lengthscale_param(self, m: Kernel) -> Tensor:
         # Used by the lengthscale_prior
         return m.lengthscale
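To illustrate when a subclass might override the new property, the hypothetical kernel below returns False from `_lazily_evaluate`, forcing eager evaluation. The class name and its forward implementation are made up for this sketch; only the property override reflects the API added above.

import torch
from gpytorch.kernels import Kernel

class EagerSquaredExpKernel(Kernel):
    """Hypothetical kernel that always evaluates its covariance matrix eagerly."""

    has_lengthscale = True

    @property
    def _lazily_evaluate(self) -> bool:
        # Returning False makes kernel(x1, x2) call forward() right away and wrap the
        # dense result with to_linear_operator, instead of building a KernelLinearOperator.
        return False

    def forward(self, x1, x2, diag=False, **params):
        x1_, x2_ = x1.div(self.lengthscale), x2.div(self.lengthscale)
        return self.covar_dist(x1_, x2_, square_dist=True, diag=diag).div(-2.0).exp()

k = EagerSquaredExpKernel()
covar = k(torch.randn(10, 3), torch.randn(10, 3))  # evaluated eagerly, then wrapped as a LinearOperator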
@@ -451,7 +522,7 @@ def sub_kernels(self) -> Iterable[Kernel]:
             yield kernel
 
     def __call__(
-        self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, last_dim_is_batch: bool = False, **params
+        self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, last_dim_is_batch: bool = False, **kwargs
     ) -> Union[LazyEvaluatedKernelTensor, LinearOperator, Tensor]:
         r"""
         Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
@@ -508,7 +579,7 @@ def __call__(
             )
 
         if diag:
-            res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **params)
+            res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **kwargs)
             # Did this Kernel eat the diag option?
             # If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output
             if not isinstance(res, LazyEvaluatedKernelTensor):
@@ -517,11 +588,42 @@ def __call__(
             return res
 
         else:
-            if settings.lazily_evaluate_kernels.on():
-                res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, **params)
+            if (settings.lazily_evaluate_kernels.on() and self._lazily_evaluate) or last_dim_is_batch:
+                num_outputs_per_input = self.num_outputs_per_input(x1_, x2_)
+                named_parameters = tuple(self.named_parameters())
+
+                if last_dim_is_batch:
+                    x1_ = x1_.transpose(-1, -2).unsqueeze(-1)
+                    x2_ = x2_.transpose(-1, -2).unsqueeze(-1)
+
+                if len(named_parameters):
+                    param_names, params = zip(*named_parameters)
+                    param_batch_shapes = [self.batch_shape] * len(params)
+                    if last_dim_is_batch:
+                        params = [
+                            param.unsqueeze(len(param_batch_shape)).transpose(-1, len(param_batch_shape))
+                            for param, param_batch_shape in zip(params, param_batch_shapes)
+                        ]
+                        param_batch_shapes = [
+                            torch.Size([*param_batch_shape, x1_.size(-3)]) for param_batch_shape in param_batch_shapes
+                        ]
+                    res = KernelLinearOperator(
+                        x1_,
+                        x2_,
+                        *params,
+                        covar_func=self._kernel_linear_operator_covar_func,
+                        num_outputs_per_input=num_outputs_per_input,
+                        param_batch_shapes=param_batch_shapes,
+                        param_names=param_names,
+                        **kwargs,
+                    )
+                else:
+                    res = KernelLinearOperator(
+                        x1_, x2_, covar_func=self.forward, num_outputs_per_input=num_outputs_per_input, **kwargs
+                    )
             else:
                 res = to_linear_operator(
-                    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
+                    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **kwargs)
                 )
         return res
 
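Assuming this diff is applied (and that RBFKernel keeps the base-class default of `_lazily_evaluate = True`), a rough usage sketch of the new code path in `__call__`:

import torch
from gpytorch.kernels import RBFKernel

kernel = RBFKernel()
x = torch.randn(200, 3)

lazy_covar = kernel(x, x)                     # KernelLinearOperator; forward() is deferred
block = lazy_covar[..., :10, :10].to_dense()  # indexing selects entries without materializing everything first
full = lazy_covar.to_dense()                  # or force full evaluation on demand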
@@ -593,13 +695,17 @@ class AdditiveKernel(Kernel):
     :param kernels: Kernels to add together.
     """
 
+    def __init__(self, *kernels: Iterable[Kernel]):
+        super(AdditiveKernel, self).__init__()
+        self.kernels = ModuleList(kernels)
+
     @property
     def is_stationary(self) -> bool:
         return all(k.is_stationary for k in self.kernels)
 
-    def __init__(self, *kernels: Iterable[Kernel]):
-        super(AdditiveKernel, self).__init__()
-        self.kernels = ModuleList(kernels)
+    @property
+    def _lazily_evaluate(self) -> bool:
+        return all(k._lazily_evaluate for k in self.kernels)
 
     def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
         res = ZeroLinearOperator() if not diag else 0
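A loose, standalone illustration of why a sum of kernels can stay lazy (this is plain linear_operator behavior, not part of the diff): adding two lazily represented kernel matrices produces a SumLinearOperator, so neither summand needs to be densified up front.

import torch
from gpytorch.kernels import RBFKernel

x = torch.randn(50, 2)
k1, k2 = RBFKernel()(x, x), RBFKernel()(x, x)
lazy_sum = k1 + k2                  # SumLinearOperator; both summands remain lazy
print(type(lazy_sum).__name__)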
@@ -635,13 +741,17 @@ class ProductKernel(Kernel):
     :param kernels: Kernels to multiply together.
     """
 
+    def __init__(self, *kernels: Iterable[Kernel]):
+        super(ProductKernel, self).__init__()
+        self.kernels = ModuleList(kernels)
+
     @property
     def is_stationary(self) -> bool:
         return all(k.is_stationary for k in self.kernels)
 
-    def __init__(self, *kernels: Iterable[Kernel]):
-        super(ProductKernel, self).__init__()
-        self.kernels = ModuleList(kernels)
+    @property
+    def _lazily_evaluate(self) -> bool:
+        return False
 
     def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
         x1_eq_x2 = torch.equal(x1, x2)
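Putting the two composite kernels side by side (a sketch that assumes this diff and that RBFKernel keeps the base-class default):

from gpytorch.kernels import AdditiveKernel, ProductKernel, RBFKernel

add_kernel = AdditiveKernel(RBFKernel(), RBFKernel())
prod_kernel = ProductKernel(RBFKernel(), RBFKernel())

print(add_kernel._lazily_evaluate)    # True: every summand opts into lazy evaluation
print(prod_kernel._lazily_evaluate)   # False: products fall back to the eager to_linear_operator path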