@@ -231,9 +231,7 @@ def _set_lengthscale(self, value: Tensor):
231231 self .initialize (raw_lengthscale = self .raw_lengthscale_constraint .inverse_transform (value ))
232232
233233 @abstractmethod
234- def forward (
235- self , x1 : Tensor , x2 : Tensor , diag : bool = False , last_dim_is_batch : bool = False , ** params
236- ) -> Union [Tensor , LinearOperator ]:
234+ def forward (self , x1 : Tensor , x2 : Tensor , diag : bool = False , ** params ) -> Union [Tensor , LinearOperator ]:
237235 r"""
238236 Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
239237 This method should be implemented by all Kernel subclasses.
@@ -242,16 +240,11 @@ def forward(
242240 :param x2: Second set of data (... x M x D).
243241 :param diag: Should the Kernel compute the whole kernel, or just the diag?
244242 If True, it must be the case that `x1 == x2`. (Default: False.)
245- :param last_dim_is_batch: If True, treat the last dimension
246- of `x1` and `x2` as another batch dimension.
247- (Useful for additive structure over the dimensions). (Default: False.)
248243
249244 :return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode:
250245
251246 * `full_covar`: `... x N x M`
252- * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
253247 * `diag`: `... x N`
254- * `diag` with `last_dim_is_batch=True`: `... x K x N`
255248 """
256249 raise NotImplementedError ()
257250
@@ -314,7 +307,6 @@ def covar_dist(
314307 x1 : Tensor ,
315308 x2 : Tensor ,
316309 diag : bool = False ,
317- last_dim_is_batch : bool = False ,
318310 square_dist : bool = False ,
319311 ** params ,
320312 ) -> Tensor :
@@ -326,22 +318,13 @@ def covar_dist(
326318 :param x2: Second set of data (... x M x D).
327319 :param diag: Should the Kernel compute the whole kernel, or just the diag?
328320 If True, it must be the case that `x1 == x2`. (Default: False.)
329- :param last_dim_is_batch: If True, treat the last dimension
330- of `x1` and `x2` as another batch dimension.
331- (Useful for additive structure over the dimensions). (Default: False.)
332321 :param square_dist:
333322 If True, returns the squared distance rather than the standard distance. (Default: False.)
334323 :return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode:
335324
336325 * `full_covar`: `... x N x M`
337- * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
338326 * `diag`: `... x N`
339- * `diag` with `last_dim_is_batch=True`: `... x K x N`
340327 """
341- if last_dim_is_batch :
342- x1 = x1 .transpose (- 1 , - 2 ).unsqueeze (- 1 )
343- x2 = x2 .transpose (- 1 , - 2 ).unsqueeze (- 1 )
344-
345328 x1_eq_x2 = torch .equal (x1 , x2 )
346329 res = None
347330
@@ -457,7 +440,7 @@ def sub_kernels(self) -> Iterable[Kernel]:
457440 yield kernel
458441
459442 def __call__ (
460- self , x1 : Tensor , x2 : Optional [Tensor ] = None , diag : bool = False , last_dim_is_batch : bool = False , ** params
443+ self , x1 : Tensor , x2 : Optional [Tensor ] = None , diag : bool = False , ** params
461444 ) -> Union [LazyEvaluatedKernelTensor , LinearOperator , Tensor ]:
462445 r"""
463446 Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
@@ -473,27 +456,13 @@ def __call__(
473456 (If `None`, then `x2` is set to `x1`.)
474457 :param diag: Should the Kernel compute the whole kernel, or just the diag?
475458 If True, it must be the case that `x1 == x2`. (Default: False.)
476- :param last_dim_is_batch: If True, treat the last dimension
477- of `x1` and `x2` as another batch dimension.
478- (Useful for additive structure over the dimensions). (Default: False.)
479459
480460 :return: An object that will lazily evaluate to the kernel matrix or vector.
481461 The shape depends on the kernel's evaluation mode:
482462
483463 * `full_covar`: `... x N x M`
484- * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
485464 * `diag`: `... x N`
486- * `diag` with `last_dim_is_batch=True`: `... x K x N`
487465 """
488- if last_dim_is_batch :
489- warnings .warn (
490- "The last_dim_is_batch argument is deprecated, and will be removed in GPyTorch 2.0. "
491- "If you are using it as part of AdditiveStructureKernel or ProductStructureKernel, "
492- 'please update your code according to the "Kernels with Additive or Product Structure" '
493- "tutorial in the GPyTorch docs." ,
494- DeprecationWarning ,
495- )
496-
497466 x1_ , x2_ = x1 , x2
498467
499468 # Select the active dimensions
@@ -523,7 +492,7 @@ def __call__(
523492 )
524493
525494 if diag :
526- res = super (Kernel , self ).__call__ (x1_ , x2_ , diag = True , last_dim_is_batch = last_dim_is_batch , ** params )
495+ res = super (Kernel , self ).__call__ (x1_ , x2_ , diag = True , ** params )
527496 # Did this Kernel eat the diag option?
528497 # If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output
529498 if not isinstance (res , LazyEvaluatedKernelTensor ):
@@ -533,11 +502,9 @@ def __call__(
533502
534503 else :
535504 if settings .lazily_evaluate_kernels .on ():
536- res = LazyEvaluatedKernelTensor (x1_ , x2_ , kernel = self , last_dim_is_batch = last_dim_is_batch , ** params )
505+ res = LazyEvaluatedKernelTensor (x1_ , x2_ , kernel = self , ** params )
537506 else :
538- res = to_linear_operator (
539- super (Kernel , self ).__call__ (x1_ , x2_ , last_dim_is_batch = last_dim_is_batch , ** params )
540- )
507+ res = to_linear_operator (super (Kernel , self ).__call__ (x1_ , x2_ , ** params ))
541508 return res
542509
543510 def __getstate__ (self ):
0 commit comments