|
3 | 3 | import warnings |
4 | 4 | from abc import abstractmethod |
5 | 5 | from copy import deepcopy |
| 6 | +from typing import Optional, Tuple |
6 | 7 |
|
7 | 8 | import torch |
8 | 9 | from torch.nn import ModuleList |
9 | 10 |
|
10 | 11 | from .. import settings |
11 | | -from ..constraints import Positive |
| 12 | +from ..constraints import Interval, Positive |
12 | 13 | from ..lazy import LazyEvaluatedKernelTensor, ZeroLazyTensor, delazify, lazify |
13 | 14 | from ..models import exact_prediction_strategies |
14 | 15 | from ..module import Module |
| 16 | +from ..priors import Prior |
15 | 17 | from ..utils.broadcasting import _mul_broadcast_shape |
16 | 18 |
|
17 | 19 |
|
@@ -131,12 +133,12 @@ class Kernel(Module): |
131 | 133 |
|
132 | 134 | def __init__( |
133 | 135 | self, |
134 | | - ard_num_dims=None, |
135 | | - batch_shape=torch.Size([]), |
136 | | - active_dims=None, |
137 | | - lengthscale_prior=None, |
138 | | - lengthscale_constraint=None, |
139 | | - eps=1e-6, |
| 136 | + ard_num_dims: Optional[int] = None, |
| 137 | + batch_shape: Optional[torch.Size] = torch.Size([]), |
| 138 | + active_dims: Optional[Tuple[int, ...]] = None, |
| 139 | + lengthscale_prior: Optional[Prior] = None, |
| 140 | + lengthscale_constraint: Optional[Interval] = None, |
| 141 | + eps: Optional[float] = 1e-6, |
140 | 142 | **kwargs, |
141 | 143 | ): |
142 | 144 | super(Kernel, self).__init__() |
@@ -167,6 +169,8 @@ def __init__( |
167 | 169 | parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, lengthscale_num_dims)), |
168 | 170 | ) |
169 | 171 | if lengthscale_prior is not None: |
| 172 | + if not isinstance(lengthscale_prior, Prior): |
| 173 | + raise TypeError("Expected gpytorch.priors.Prior but got " + type(lengthscale_prior).__name__) |
170 | 174 | self.register_prior( |
171 | 175 | "lengthscale_prior", lengthscale_prior, lambda m: m.lengthscale, lambda m, v: m._set_lengthscale(v) |
172 | 176 | ) |
|
0 commit comments