Skip to content

Commit f51982e

Browse files
committed
Kernel batch size recurses through module lists.
[Fixes #1672]
1 parent b237019 commit f51982e

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

gpytorch/kernels/kernel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,8 @@ def covar_dist(
331331
return res
332332

333333
def named_sub_kernels(self):
334-
for name, module in self._modules.items():
335-
if isinstance(module, Kernel):
334+
for name, module in self.named_modules():
335+
if module is not self and isinstance(module, Kernel):
336336
yield name, module
337337

338338
def num_outputs_per_input(self, x1, x2):

test/kernels/test_additive_and_product_kernels.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,23 @@ def test_computes_product_of_radial_basis_function(self):
5656
res = kernel(a, b).evaluate()
5757
self.assertLess(torch.norm(res - actual), 2e-5)
5858

59+
def test_computes_product_of_radial_basis_function_batch(self):
60+
a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
61+
b = torch.tensor([0, 2], dtype=torch.float).view(2, 1)
62+
lengthscale = 2
63+
64+
kernel_1 = RBFKernel(batch_shape=torch.Size([4])).initialize(lengthscale=lengthscale)
65+
kernel_2 = RBFKernel().initialize(lengthscale=lengthscale)
66+
kernel = kernel_1 * kernel_2
67+
68+
actual = torch.tensor([[16, 4], [4, 0], [64, 36]], dtype=torch.float)
69+
actual = actual.mul_(-0.5).div_(lengthscale ** 2).exp() ** 2
70+
actual = actual.repeat(4, 1, 1)
71+
72+
kernel.eval()
73+
res = kernel(a, b).evaluate()
74+
self.assertLess(torch.norm(res - actual), 2e-5)
75+
5976
def test_computes_sum_of_radial_basis_function(self):
6077
a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
6178
b = torch.tensor([0, 2], dtype=torch.float).view(2, 1)

0 commit comments

Comments
 (0)