Skip to content

Commit 8979210

Browse files
authored
Merge pull request #2388 from douglas-boubert/grad_kernels_only_create_covariance_if_needed
Stop rbf_kernel_grad and rbf_kernel_gradgrad creating the full covariance matrix unnecessarily
2 parents 090d6e1 + eed36a4 commit 8979210

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

gpytorch/kernels/rbf_kernel_grad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def forward(self, x1, x2, diag=False, **params):
5757
n1, d = x1.shape[-2:]
5858
n2 = x2.shape[-2]
5959

60-
K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)
61-
6260
if not diag:
61+
K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)
62+
6363
# Scale the inputs by the lengthscale (for stability)
6464
x1_ = x1.div(self.lengthscale)
6565
x2_ = x2.div(self.lengthscale)

gpytorch/kernels/rbf_kernel_gradgrad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ def forward(self, x1, x2, diag=False, **params):
5757
n1, d = x1.shape[-2:]
5858
n2 = x2.shape[-2]
5959

60-
K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype)
61-
6260
if not diag:
61+
K = torch.zeros(*batch_shape, n1 * (2 * d + 1), n2 * (2 * d + 1), device=x1.device, dtype=x1.dtype)
62+
6363
# Scale the inputs by the lengthscale (for stability)
6464
x1_ = x1.div(self.lengthscale)
6565
x2_ = x2.div(self.lengthscale)

0 commit comments

Comments
 (0)