Skip to content

Commit 28ee4ca

Browse files
Merge branch 'master' into dGPFantasize
2 parents 2f7a3cf + 527546e commit 28ee4ca

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

gpytorch/kernels/kernel.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525

2626
def sq_dist(x1, x2, x1_eq_x2=False):
27+
"""Equivalent to the square of `torch.cdist` with p=2."""
2728
# TODO: use torch squared cdist once implemented: https://github.com/pytorch/pytorch/pull/25799
2829
adjustment = x1.mean(-2, keepdim=True)
2930
x1 = x1 - adjustment
@@ -49,7 +50,12 @@ def sq_dist(x1, x2, x1_eq_x2=False):
4950

5051

5152
def dist(x1, x2, x1_eq_x2=False):
52-
# TODO: use torch cdist once implementation is improved: https://github.com/pytorch/pytorch/pull/25799
53+
"""
54+
Equivalent to `torch.cdist` with p=2, but clamps the minimum element to 1e-15.
55+
"""
56+
if not x1_eq_x2:
57+
res = torch.cdist(x1, x2)
58+
return res.clamp_min(1e-15)
5359
res = sq_dist(x1, x2, x1_eq_x2=x1_eq_x2)
5460
return res.clamp_min_(1e-30).sqrt_()
5561

0 commit comments

Comments
 (0)