Commit dc61fb1

[Breaking Change] remove last_dim_is_batch from remaining kernels
1 parent: fe08545

25 files changed: +25 −388 lines
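For code that still calls a kernel with `last_dim_is_batch=True`, the equivalent behavior can be recovered by folding the feature dimension into the batch before the call, which is exactly what the removed branches did internally (and what the "Kernels with Additive or Product Structure" tutorial covers). A minimal sketch; the kernel choice and data shapes are illustrative and not part of this commit:

import torch
from gpytorch.kernels import RBFKernel

x = torch.randn(10, 3)  # N=10 points, D=3 dimensions
kernel = RBFKernel()

# Previously: kernel(x, last_dim_is_batch=True) treated each of the D dimensions
# as a batch and produced a D x N x N output.
# Now: perform the same reshape explicitly before calling the kernel.
x_batched = x.transpose(-1, -2).unsqueeze(-1)  # D x N x 1
covar = kernel(x_batched)                      # lazily evaluates to 3 x 10 x 10
print(covar.shape)  # torch.Size([3, 10, 10])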

gpytorch/kernels/constant_kernel.py

Lines changed: 0 additions & 10 deletions
@@ -90,25 +90,18 @@ def forward(
        x1: Tensor,
        x2: Tensor,
        diag: Optional[bool] = False,
-        last_dim_is_batch: Optional[bool] = False,
    ) -> Tensor:
        """Evaluates the constant kernel.

        Args:
            x1: First input tensor of shape (batch_shape x n1 x d).
            x2: Second input tensor of shape (batch_shape x n2 x d).
            diag: If True, returns the diagonal of the covariance matrix.
-            last_dim_is_batch: If True, the last dimension of size `d` of the input
-                tensors are treated as a batch dimension.

        Returns:
            A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
            constant covariance values if diag is False, resp. True.
        """
-        if last_dim_is_batch:
-            x1 = x1.transpose(-1, -2).unsqueeze(-1)
-            x2 = x2.transpose(-1, -2).unsqueeze(-1)
-
        dtype = torch.promote_types(x1.dtype, x2.dtype)
        batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
        shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
@@ -117,7 +110,4 @@ def forward(
        if not diag:
            constant = constant.unsqueeze(-1)

-        if last_dim_is_batch:
-            constant = constant.unsqueeze(-1)
-
        return constant.expand(shape)
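With the batch branch gone, `ConstantKernel.forward` always broadcasts to `batch_shape x n1 x n2` (or `batch_shape x n1` when `diag=True`). A quick shape check, assuming toy inputs:

import torch
from gpytorch.kernels import ConstantKernel

k = ConstantKernel()
x1 = torch.randn(2, 5, 3)  # batch=2, n1=5, d=3
x2 = torch.randn(2, 4, 3)  # batch=2, n2=4, d=3

print(k(x1, x2).shape)             # torch.Size([2, 5, 4])
print(k(x1, x1, diag=True).shape)  # torch.Size([2, 5])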

gpytorch/kernels/kernel.py

Lines changed: 5 additions & 38 deletions
@@ -231,9 +231,7 @@ def _set_lengthscale(self, value: Tensor):
        self.initialize(raw_lengthscale=self.raw_lengthscale_constraint.inverse_transform(value))

    @abstractmethod
-    def forward(
-        self, x1: Tensor, x2: Tensor, diag: bool = False, last_dim_is_batch: bool = False, **params
-    ) -> Union[Tensor, LinearOperator]:
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Union[Tensor, LinearOperator]:
        r"""
        Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
        This method should be implemented by all Kernel subclasses.
@@ -242,16 +240,11 @@ def forward(
        :param x2: Second set of data (... x M x D).
        :param diag: Should the Kernel compute the whole kernel, or just the diag?
            If True, it must be the case that `x1 == x2`. (Default: False.)
-        :param last_dim_is_batch: If True, treat the last dimension
-            of `x1` and `x2` as another batch dimension.
-            (Useful for additive structure over the dimensions). (Default: False.)

        :return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode:

        * `full_covar`: `... x N x M`
-        * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
        * `diag`: `... x N`
-        * `diag` with `last_dim_is_batch=True`: `... x K x N`
        """
        raise NotImplementedError()

@@ -314,7 +307,6 @@ def covar_dist(
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
-        last_dim_is_batch: bool = False,
        square_dist: bool = False,
        **params,
    ) -> Tensor:
@@ -326,22 +318,13 @@ def covar_dist(
        :param x2: Second set of data (... x M x D).
        :param diag: Should the Kernel compute the whole kernel, or just the diag?
            If True, it must be the case that `x1 == x2`. (Default: False.)
-        :param last_dim_is_batch: If True, treat the last dimension
-            of `x1` and `x2` as another batch dimension.
-            (Useful for additive structure over the dimensions). (Default: False.)
        :param square_dist:
            If True, returns the squared distance rather than the standard distance. (Default: False.)
        :return: The kernel matrix or vector. The shape depends on the kernel's evaluation mode:

        * `full_covar`: `... x N x M`
-        * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
        * `diag`: `... x N`
-        * `diag` with `last_dim_is_batch=True`: `... x K x N`
        """
-        if last_dim_is_batch:
-            x1 = x1.transpose(-1, -2).unsqueeze(-1)
-            x2 = x2.transpose(-1, -2).unsqueeze(-1)
-
        x1_eq_x2 = torch.equal(x1, x2)
        res = None

@@ -457,7 +440,7 @@ def sub_kernels(self) -> Iterable[Kernel]:
            yield kernel

    def __call__(
-        self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, last_dim_is_batch: bool = False, **params
+        self, x1: Tensor, x2: Optional[Tensor] = None, diag: bool = False, **params
    ) -> Union[LazyEvaluatedKernelTensor, LinearOperator, Tensor]:
        r"""
        Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
@@ -473,27 +456,13 @@ def __call__(
            (If `None`, then `x2` is set to `x1`.)
        :param diag: Should the Kernel compute the whole kernel, or just the diag?
            If True, it must be the case that `x1 == x2`. (Default: False.)
-        :param last_dim_is_batch: If True, treat the last dimension
-            of `x1` and `x2` as another batch dimension.
-            (Useful for additive structure over the dimensions). (Default: False.)

        :return: An object that will lazily evaluate to the kernel matrix or vector.
            The shape depends on the kernel's evaluation mode:

        * `full_covar`: `... x N x M`
-        * `full_covar` with `last_dim_is_batch=True`: `... x K x N x M`
        * `diag`: `... x N`
-        * `diag` with `last_dim_is_batch=True`: `... x K x N`
        """
-        if last_dim_is_batch:
-            warnings.warn(
-                "The last_dim_is_batch argument is deprecated, and will be removed in GPyTorch 2.0. "
-                "If you are using it as part of AdditiveStructureKernel or ProductStructureKernel, "
-                'please update your code according to the "Kernels with Additive or Product Structure" '
-                "tutorial in the GPyTorch docs.",
-                DeprecationWarning,
-            )
-
        x1_, x2_ = x1, x2

        # Select the active dimensions
@@ -523,7 +492,7 @@ def __call__(
            )

        if diag:
-            res = super(Kernel, self).__call__(x1_, x2_, diag=True, last_dim_is_batch=last_dim_is_batch, **params)
+            res = super(Kernel, self).__call__(x1_, x2_, diag=True, **params)
            # Did this Kernel eat the diag option?
            # If it does not return a LazyEvaluatedKernelTensor, we can call diag on the output
            if not isinstance(res, LazyEvaluatedKernelTensor):
@@ -533,11 +502,9 @@ def __call__(

        else:
            if settings.lazily_evaluate_kernels.on():
-                res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, last_dim_is_batch=last_dim_is_batch, **params)
+                res = LazyEvaluatedKernelTensor(x1_, x2_, kernel=self, **params)
            else:
-                res = to_linear_operator(
-                    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
-                )
+                res = to_linear_operator(super(Kernel, self).__call__(x1_, x2_, **params))
        return res

    def __getstate__(self):
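Since `covar_dist` no longer accepts `last_dim_is_batch`, callers that want per-dimension distances can apply the same transpose/unsqueeze that the removed branch performed. A sketch of the equivalence; the RBFKernel instance below is only a convenient way to reach the inherited `covar_dist`:

import torch
from gpytorch.kernels import RBFKernel

kernel = RBFKernel()
x1 = torch.randn(5, 2)  # N=5, D=2
x2 = torch.randn(4, 2)  # M=4, D=2

# What covar_dist(x1, x2, last_dim_is_batch=True) used to do internally:
dist = kernel.covar_dist(
    x1.transpose(-1, -2).unsqueeze(-1),  # D x N x 1
    x2.transpose(-1, -2).unsqueeze(-1),  # D x M x 1
)
print(dist.shape)  # torch.Size([2, 5, 4]), i.e. K x N x M with K = D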

gpytorch/kernels/linear_kernel.py

Lines changed: 1 addition & 8 deletions
@@ -85,12 +85,8 @@ def _set_variance(self, value: Union[float, Tensor]):
        value = torch.as_tensor(value).to(self.raw_variance)
        self.initialize(raw_variance=self.raw_variance_constraint.inverse_transform(value))

-    def forward(
-        self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, last_dim_is_batch: Optional[bool] = False, **params
-    ) -> LinearOperator:
+    def forward(self, x1: Tensor, x2: Tensor, diag: Optional[bool] = False, **params) -> LinearOperator:
        x1_ = x1 * self.variance.sqrt()
-        if last_dim_is_batch:
-            x1_ = x1_.transpose(-1, -2).unsqueeze(-1)

        if x1.size() == x2.size() and torch.equal(x1, x2):
            # Use RootLinearOperator when x1 == x2 for efficiency when composing
@@ -99,9 +95,6 @@ def forward(

        else:
            x2_ = x2 * self.variance.sqrt()
-            if last_dim_is_batch:
-                x2_ = x2_.transpose(-1, -2).unsqueeze(-1)
-
            prod = MatmulLinearOperator(x1_, x2_.transpose(-2, -1))

        if diag:

gpytorch/kernels/matern_kernel.py

Lines changed: 0 additions & 1 deletion
@@ -89,7 +89,6 @@ def forward(self, x1, x2, diag=False, **params):
            or x2.requires_grad
            or (self.ard_num_dims is not None and self.ard_num_dims > 1)
            or diag
-            or params.get("last_dim_is_batch", False)
            or trace_mode.on()
        ):
            mean = x1.mean(dim=-2, keepdim=True)

gpytorch/kernels/multitask_kernel.py

Lines changed: 1 addition & 3 deletions
@@ -43,9 +43,7 @@ def __init__(
        self.data_covar_module = data_covar_module
        self.num_tasks = num_tasks

-    def forward(self, x1, x2, diag=False, last_dim_is_batch=False, **params):
-        if last_dim_is_batch:
-            raise RuntimeError("MultitaskKernel does not accept the last_dim_is_batch argument.")
+    def forward(self, x1, x2, diag=False, **params):
        covar_i = self.task_covar_module.covar_matrix
        if len(x1.shape[:-2]):
            covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1)

gpytorch/kernels/periodic_kernel.py

Lines changed: 4 additions & 7 deletions
@@ -124,23 +124,20 @@ def _set_period_length(self, value):
        self.initialize(raw_period_length=self.raw_period_length_constraint.inverse_transform(value))

    def forward(self, x1, x2, diag=False, **params):
-        # Pop this argument so that we can manually sum over dimensions
-        last_dim_is_batch = params.pop("last_dim_is_batch", False)
        # Get lengthscale
        lengthscale = self.lengthscale

        x1_ = x1.div(self.period_length / math.pi)
        x2_ = x2.div(self.period_length / math.pi)
-        # We are automatically overriding last_dim_is_batch here so that we can manually sum over dimensions.
-        diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True, **params)
+        diff = self.covar_dist(
+            x1_.transpose(-1, -2).unsqueeze(-1), x2_.transpose(-1, -2).unsqueeze(-1), diag=diag, **params
+        )  # A ... x D x N x N kernel

        if diag:
            lengthscale = lengthscale[..., 0, :, None]
        else:
            lengthscale = lengthscale[..., 0, :, None, None]
        exp_term = diff.sin().pow(2.0).div(lengthscale).mul(-2.0)
-
-        if not last_dim_is_batch:
-            exp_term = exp_term.sum(dim=(-2 if diag else -3))
+        exp_term = exp_term.sum(dim=(-2 if diag else -3))

        return exp_term.exp()
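The rewritten `PeriodicKernel.forward` keeps the default behavior: each input dimension is moved into the batch, `covar_dist` returns a `... x D x N x N` stack of per-dimension differences, and the squared-sine terms are summed over `D` (dim=-3) before exponentiating. A shape check under assumed toy data, using `to_dense()` from a recent linear_operator release to evaluate the lazy result:

import torch
from gpytorch.kernels import PeriodicKernel

x = torch.randn(6, 2)  # N=6 points, D=2 dimensions (illustrative)
kernel = PeriodicKernel()

# Internally x.transpose(-1, -2).unsqueeze(-1) has shape D x N x 1, so the
# distance stack is D x N x N; summing over dim=-3 collapses D before exp().
covar = kernel(x, x).to_dense()
print(covar.shape)  # torch.Size([6, 6])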

gpytorch/kernels/piecewise_polynomial_kernel.py

Lines changed: 3 additions & 10 deletions
@@ -101,20 +101,13 @@ def __init__(self, q: Optional[int] = 2, **kwargs):
            raise ValueError("q expected to be 0, 1, 2 or 3")
        self.q = q

-    def forward(self, x1: Tensor, x2: Tensor, last_dim_is_batch: bool = False, diag: bool = False, **params) -> Tensor:
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **params) -> Tensor:
        x1_ = x1.div(self.lengthscale)
        x2_ = x2.div(self.lengthscale)
-        if last_dim_is_batch is True:
-            D = x1.shape[1]
-        else:
-            D = x1.shape[-1]
+        D = x1.shape[-1]
        j = math.floor(D / 2.0) + self.q + 1
-        if last_dim_is_batch and diag:
-            r = self.covar_dist(x1_, x2_, last_dim_is_batch=True, diag=True)
-        elif diag:
+        if diag:
            r = self.covar_dist(x1_, x2_, diag=True)
-        elif last_dim_is_batch:
-            r = self.covar_dist(x1_, x2_, last_dim_is_batch=True)
        else:
            r = self.covar_dist(x1_, x2_)
        cov_matrix = _fmax(r, j, self.q) * _get_cov(r, j, self.q)

gpytorch/kernels/polynomial_kernel.py

Lines changed: 0 additions & 5 deletions
@@ -81,15 +81,10 @@ def forward(
        x1: torch.Tensor,
        x2: torch.Tensor,
        diag: Optional[bool] = False,
-        last_dim_is_batch: Optional[bool] = False,
        **params,
    ) -> torch.Tensor:
        offset = self.offset.view(*self.batch_shape, 1, 1)

-        if last_dim_is_batch:
-            x1 = x1.transpose(-1, -2).unsqueeze(-1)
-            x2 = x2.transpose(-1, -2).unsqueeze(-1)
-
        if diag:
            return ((x1 * x2).sum(dim=-1) + self.offset).pow(self.power)
gpytorch/kernels/polynomial_kernel_grad.py

Lines changed: 0 additions & 1 deletion
@@ -13,7 +13,6 @@ def forward(
        x1: torch.Tensor,
        x2: torch.Tensor,
        diag: Optional[bool] = False,
-        last_dim_is_batch: Optional[bool] = False,
        **params,
    ) -> torch.Tensor:
        offset = self.offset.view(*self.batch_shape, 1, 1)

gpytorch/kernels/rbf_kernel.py

Lines changed: 0 additions & 1 deletion
@@ -71,7 +71,6 @@ def forward(self, x1, x2, diag=False, **params):
            or x2.requires_grad
            or (self.ard_num_dims is not None and self.ard_num_dims > 1)
            or diag
-            or params.get("last_dim_is_batch", False)
            or trace_mode.on()
        ):
            x1_ = x1.div(self.lengthscale)
