Skip to content

Commit d19f52e

Browse files
authored
Merge pull request #2378 from j-wilson/patch_matern
Minor patch to Matern covariances
2 parents f73fa7d + 776bbeb commit d19f52e

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

gpytorch/functions/matern_covariance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def forward(ctx, x1, x2, lengthscale, nu, dist_func):
1313
# Subtract mean for numerical stability. Won't affect computations
1414
# because covariance matrix is stationary.
1515
needs_grad = any(ctx.needs_input_grad)
16-
mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]
16+
mean = x1.mean(dim=-2, keepdim=True)
1717
x1_ = (x1 - mean).div(lengthscale)
1818
x2_ = (x2 - mean).div(lengthscale)
1919
scaled_unitless_dist = dist_func(x1_, x2_).mul_(math.sqrt(2 * nu))

gpytorch/kernels/matern_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def forward(self, x1, x2, diag=False, **params):
9292
or params.get("last_dim_is_batch", False)
9393
or trace_mode.on()
9494
):
95-
mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]
95+
mean = x1.mean(dim=-2, keepdim=True)
9696

9797
x1_ = (x1 - mean).div(self.lengthscale)
9898
x2_ = (x2 - mean).div(self.lengthscale)

0 commit comments

Comments
 (0)