
Commit ade5db8

Added Type hints and exceptions in kernels (#1802)
1 parent 536a279 commit ade5db8

35 files changed: +323 -51 lines

gpytorch/kernels/additive_structure_kernel.py

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3

+from typing import Optional, Tuple
+
 from .kernel import Kernel


@@ -39,7 +41,9 @@ def is_stationary(self) -> bool:
         """
         return self.base_kernel.is_stationary

-    def __init__(self, base_kernel, num_dims, active_dims=None):
+    def __init__(
+        self, base_kernel: Kernel, num_dims: int, active_dims: Optional[Tuple[int, ...]] = None,
+    ):
         super(AdditiveStructureKernel, self).__init__(active_dims=active_dims)
         self.base_kernel = base_kernel
         self.num_dims = num_dims
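
For illustration (my own snippet, not part of the commit), the annotated constructor is used like this: base_kernel must be a gpytorch Kernel, num_dims an int, and active_dims an optional tuple of ints.

    # Illustrative usage only; sums an RBF kernel applied to each of 3 input dimensions.
    from gpytorch.kernels import AdditiveStructureKernel, RBFKernel

    kernel = AdditiveStructureKernel(base_kernel=RBFKernel(), num_dims=3)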

gpytorch/kernels/arc_kernel.py

Lines changed: 7 additions & 3 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3

 from math import pi
-from typing import Optional
+from typing import Callable, Optional

 import torch

@@ -98,8 +98,8 @@ class ArcKernel(Kernel):

     def __init__(
         self,
-        base_kernel,
-        delta_func: Optional = None,
+        base_kernel: Kernel,
+        delta_func: Optional[Callable] = None,
         angle_prior: Optional[Prior] = None,
         radius_prior: Optional[Prior] = None,
         **kwargs,
@@ -122,6 +122,8 @@ def __init__(
             name="raw_angle", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, self.last_dim)),
         )
         if angle_prior is not None:
+            if not isinstance(angle_prior, Prior):
+                raise TypeError("Expected gpytorch.priors.Prior but got " + type(angle_prior).__name__)
             self.register_prior(
                 "angle_prior", angle_prior, lambda m: m.angle, lambda m, v: m._set_angle(v),
             )
@@ -133,6 +135,8 @@ def __init__(
         )

         if radius_prior is not None:
+            if not isinstance(radius_prior, Prior):
+                raise TypeError("Expected gpytorch.priors.Prior but got " + type(radius_prior).__name__)
             self.register_prior(
                 "radius_prior", radius_prior, lambda m: m.radius, lambda m, v: m._set_radius(v),
             )

gpytorch/kernels/cosine_kernel.py

Lines changed: 11 additions & 2 deletions
@@ -1,10 +1,12 @@
 #!/usr/bin/env python3

 import math
+from typing import Optional

 import torch

-from ..constraints import Positive
+from ..constraints import Interval, Positive
+from ..priors import Prior
 from .kernel import Kernel


@@ -56,7 +58,12 @@ class CosineKernel(Kernel):

     is_stationary = True

-    def __init__(self, period_length_prior=None, period_length_constraint=None, **kwargs):
+    def __init__(
+        self,
+        period_length_prior: Optional[Prior] = None,
+        period_length_constraint: Optional[Interval] = None,
+        **kwargs,
+    ):
         super(CosineKernel, self).__init__(**kwargs)

         self.register_parameter(
@@ -67,6 +74,8 @@ def __init__(self, period_length_prior=None, period_length_constraint=None, **kwargs):
             period_length_constraint = Positive()

         if period_length_prior is not None:
+            if not isinstance(period_length_prior, Prior):
+                raise TypeError("Expected gpytorch.priors.Prior but got " + type(period_length_prior).__name__)
             self.register_prior(
                 "period_length_prior",
                 period_length_prior,
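
The isinstance checks added throughout this commit make a wrongly typed prior fail explicitly at construction time. An illustrative snippet (not from the repository) of the new behavior, using CosineKernel because it takes no base kernel:

    # Illustration only, not part of the commit.
    from gpytorch.kernels import CosineKernel
    from gpytorch.priors import GammaPrior

    CosineKernel(period_length_prior=GammaPrior(3.0, 6.0))  # OK: a gpytorch Prior

    try:
        CosineKernel(period_length_prior=2.0)  # a float is not a Prior
    except TypeError as err:
        print(err)  # Expected gpytorch.priors.Prior but got float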

gpytorch/kernels/cylindrical_kernel.py

Lines changed: 7 additions & 1 deletion
@@ -42,7 +42,7 @@ def __init__(
         self,
         num_angular_weights: int,
         radial_base_kernel: Kernel,
-        eps: Optional[int] = 1e-6,
+        eps: Optional[float] = 1e-6,
         angular_weights_prior: Optional[Prior] = None,
         angular_weights_constraint: Optional[Interval] = None,
         alpha_prior: Optional[Prior] = None,
@@ -76,15 +76,21 @@ def __init__(
         self.register_constraint("raw_beta", beta_constraint)

         if angular_weights_prior is not None:
+            if not isinstance(angular_weights_prior, Prior):
+                raise TypeError("Expected gpytorch.priors.Prior but got " + type(angular_weights_prior).__name__)
             self.register_prior(
                 "angular_weights_prior",
                 angular_weights_prior,
                 lambda m: m.angular_weights,
                 lambda m, v: m._set_angular_weights(v),
             )
         if alpha_prior is not None:
+            if not isinstance(alpha_prior, Prior):
+                raise TypeError("Expected gpytorch.priors.Prior but got " + type(alpha_prior).__name__)
             self.register_prior("alpha_prior", alpha_prior, lambda m: m.alpha, lambda m, v: m._set_alpha(v))
         if beta_prior is not None:
+            if not isinstance(beta_prior, Prior):
+                raise TypeError("Expected gpytorch.priors.Prior but got " + type(beta_prior).__name__)
             self.register_prior("beta_prior", beta_prior, lambda m: m.beta, lambda m, v: m._set_beta(v))

     @property

gpytorch/kernels/distributional_input_kernel.py

Lines changed: 5 additions & 1 deletion
@@ -1,5 +1,7 @@
 #!/usr/bin/env python3

+from typing import Callable
+
 import torch

 from .kernel import Kernel
@@ -22,7 +24,9 @@ class DistributionalInputKernel(Kernel):
     """
     has_lengthscale = True

-    def __init__(self, distance_function, **kwargs):
+    def __init__(
+        self, distance_function: Callable, **kwargs,
+    ):
         super(DistributionalInputKernel, self).__init__(**kwargs)
         if distance_function is None:
             raise NotImplementedError("DistributionalInputKernel requires a distance function.")

gpytorch/kernels/grid_interpolation_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -73,9 +73,9 @@ def __init__(
         self,
         base_kernel: Kernel,
         grid_size: Union[int, List[int]],
-        num_dims: int = None,
+        num_dims: Optional[int] = None,
         grid_bounds: Optional[Tuple[float, float]] = None,
-        active_dims: Tuple[int, ...] = None,
+        active_dims: Optional[Tuple[int, ...]] = None,
     ):
         has_initialized_grid = 0
         grid_is_dynamic = True
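
The change from num_dims: int = None to Optional[int] = None follows PEP 484 style: a parameter whose default is None should be annotated as Optional, and strict type checkers (for example mypy with no_implicit_optional) flag the implicit form. A toy illustration, not from the repository:

    # Toy example only.
    from typing import Optional

    def implicit(num_dims: int = None): ...            # flagged by strict checkers: None is not an int
    def explicit(num_dims: Optional[int] = None): ...  # the None default matches the annotation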

gpytorch/kernels/grid_kernel.py

Lines changed: 6 additions & 2 deletions
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-from typing import List
+from typing import Optional

 import torch
 from torch import Tensor
@@ -46,7 +46,11 @@ class GridKernel(Kernel):
     is_stationary = True

     def __init__(
-        self, base_kernel: Kernel, grid: List[Tensor], interpolation_mode: bool = False, active_dims: bool = None
+        self,
+        base_kernel: Kernel,
+        grid: Tensor,
+        interpolation_mode: Optional[bool] = False,
+        active_dims: Optional[bool] = None,
     ):
         if not base_kernel.is_stationary:
             raise RuntimeError("The base_kernel for GridKernel must be stationary.")

gpytorch/kernels/index_kernel.py

Lines changed: 14 additions & 2 deletions
@@ -1,9 +1,12 @@
 #!/usr/bin/env python3

+from typing import Optional
+
 import torch

-from ..constraints import Positive
+from ..constraints import Interval, Positive
 from ..lazy import DiagLazyTensor, InterpolatedLazyTensor, PsdSumLazyTensor, RootLazyTensor
+from ..priors import Prior
 from ..utils.broadcasting import _mul_broadcast_shape
 from .kernel import Kernel

@@ -43,7 +46,14 @@ class IndexKernel(Kernel):
        The element-wise log of the :math:`\mathbf v` vector.
    """

-    def __init__(self, num_tasks, rank=1, prior=None, var_constraint=None, **kwargs):
+    def __init__(
+        self,
+        num_tasks: int,
+        rank: Optional[int] = 1,
+        prior: Optional[Prior] = None,
+        var_constraint: Optional[Interval] = None,
+        **kwargs,
+    ):
         if rank > num_tasks:
             raise RuntimeError("Cannot create a task covariance matrix larger than the number of tasks")
         super().__init__(**kwargs)
@@ -56,6 +66,8 @@ def __init__(self, num_tasks, rank=1, prior=None, var_constraint=None, **kwargs):
         )
         self.register_parameter(name="raw_var", parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks)))
         if prior is not None:
+            if not isinstance(prior, Prior):
+                raise TypeError("Expected gpytorch.priors.Prior but got " + type(prior).__name__)
             self.register_prior("IndexKernelPrior", prior, lambda m: m._eval_covar_matrix())

         self.register_constraint("raw_var", var_constraint)

gpytorch/kernels/inducing_point_kernel.py

Lines changed: 10 additions & 1 deletion
@@ -2,20 +2,29 @@

 import copy
 import math
+from typing import Optional, Tuple

 import torch
+from torch import Tensor

 from .. import settings
 from ..distributions import MultivariateNormal
 from ..lazy import DiagLazyTensor, LowRankRootAddedDiagLazyTensor, LowRankRootLazyTensor, MatmulLazyTensor, delazify
+from ..likelihoods import Likelihood
 from ..mlls import InducingPointKernelAddedLossTerm
 from ..models import exact_prediction_strategies
 from ..utils.cholesky import psd_safe_cholesky
 from .kernel import Kernel


 class InducingPointKernel(Kernel):
-    def __init__(self, base_kernel, inducing_points, likelihood, active_dims=None):
+    def __init__(
+        self,
+        base_kernel: Kernel,
+        inducing_points: Tensor,
+        likelihood: Likelihood,
+        active_dims: Optional[Tuple[int, ...]] = None,
+    ):
         super(InducingPointKernel, self).__init__(active_dims=active_dims)
         self.base_kernel = base_kernel
         self.likelihood = likelihood

gpytorch/kernels/kernel.py

Lines changed: 11 additions & 7 deletions
@@ -3,15 +3,17 @@
 import warnings
 from abc import abstractmethod
 from copy import deepcopy
+from typing import Optional, Tuple

 import torch
 from torch.nn import ModuleList

 from .. import settings
-from ..constraints import Positive
+from ..constraints import Interval, Positive
 from ..lazy import LazyEvaluatedKernelTensor, ZeroLazyTensor, delazify, lazify
 from ..models import exact_prediction_strategies
 from ..module import Module
+from ..priors import Prior
 from ..utils.broadcasting import _mul_broadcast_shape


@@ -131,12 +133,12 @@ class Kernel(Module):

     def __init__(
         self,
-        ard_num_dims=None,
-        batch_shape=torch.Size([]),
-        active_dims=None,
-        lengthscale_prior=None,
-        lengthscale_constraint=None,
-        eps=1e-6,
+        ard_num_dims: Optional[int] = None,
+        batch_shape: Optional[torch.Size] = torch.Size([]),
+        active_dims: Optional[Tuple[int, ...]] = None,
+        lengthscale_prior: Optional[Prior] = None,
+        lengthscale_constraint: Optional[Interval] = None,
+        eps: Optional[float] = 1e-6,
         **kwargs,
     ):
         super(Kernel, self).__init__()
@@ -167,6 +169,8 @@ def __init__(
             parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, lengthscale_num_dims)),
         )
         if lengthscale_prior is not None:
+            if not isinstance(lengthscale_prior, Prior):
+                raise TypeError("Expected gpytorch.priors.Prior but got " + type(lengthscale_prior).__name__)
             self.register_prior(
                 "lengthscale_prior", lengthscale_prior, lambda m: m.lengthscale, lambda m, v: m._set_lengthscale(v)
             )
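
Because every kernel inherits this constructor, downstream subclasses pick up the typed arguments and the lengthscale_prior validation through **kwargs. A minimal custom-kernel sketch (my own example, not part of this commit; ExpKernel is a made-up name and the forward body assumes the standard Kernel.covar_dist helper):

    # Sketch only; illustrates the subclassing pattern, not code from the commit.
    import torch
    from gpytorch.kernels import Kernel
    from gpytorch.priors import GammaPrior


    class ExpKernel(Kernel):
        has_lengthscale = True  # Kernel.__init__ registers lengthscale and now validates lengthscale_prior

        def forward(self, x1, x2, diag=False, **params):
            # exp(-||x1/l - x2/l||): scale inputs by the lengthscale, then exponentiate the distance
            x1_ = x1.div(self.lengthscale)
            x2_ = x2.div(self.lengthscale)
            return torch.exp(-self.covar_dist(x1_, x2_, diag=diag, **params))


    # The inherited arguments are now typed: ard_num_dims is an int, lengthscale_prior a Prior.
    kernel = ExpKernel(ard_num_dims=2, lengthscale_prior=GammaPrior(3.0, 6.0))
    covar = kernel(torch.randn(5, 2), torch.randn(5, 2))  # lazily evaluated 5 x 5 covariance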

0 commit comments
