#!/usr/bin/env python3

from typing import Optional, Tuple

import torch
from torch import Tensor

from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel


class ConstantKernel(Kernel):
    """
    Constant covariance kernel for the probabilistic inference of constant coefficients.

    ConstantKernel represents the prior variance `k(x1, x2) = var(c)` of a constant `c`.
    The prior variance of the constant is optimized during the GP hyper-parameter
    optimization stage. The actual value of the constant is computed (implicitly) using
    the linear algebraic approaches for the computation of GP samples and posteriors.

    The constant kernel `k_constant` is most useful as a modification of an arbitrary
    base kernel `k_base`:
    1) Additive constants: The modification `k_base + k_constant` allows the GP to
    infer a non-zero asymptotic value far from the training data, which generally
    leads to more accurate extrapolation. Notably, the uncertainty in this constant
    value affects the posterior covariances through the posterior inference equations.
    This is not the case when a constant prior mean is used instead, since the prior
    mean does not appear in the posterior covariance and is regularized by the
    log-determinant during the optimization of the marginal likelihood.
    2) Multiplicative constants: The modification `k_base * k_constant` allows the GP
    to modulate the variance of the kernel `k_base`, and is mathematically identical
    to `ScaleKernel(base_kernel)` with the same constant (see the example below).
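
    Example (an illustrative sketch, not part of the original source; it assumes
    this class is exported as `gpytorch.kernels.ConstantKernel`, with `RBFKernel`
    standing in for an arbitrary base kernel):
        >>> import torch
        >>> from gpytorch.kernels import ConstantKernel, RBFKernel
        >>> x = torch.randn(10, 3)
        >>> additive = RBFKernel() + ConstantKernel()  # non-zero asymptote
        >>> multiplicative = RBFKernel() * ConstantKernel()  # like ScaleKernel
        >>> additive(x, x).shape
        torch.Size([10, 10])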
    """

    has_lengthscale = False

    def __init__(
        self,
        batch_shape: Optional[torch.Size] = None,
        constant_prior: Optional[Prior] = None,
        constant_constraint: Optional[Interval] = None,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        """Constructor of ConstantKernel.

        Args:
            batch_shape: The batch shape of the kernel.
            constant_prior: Prior over the constant parameter.
            constant_constraint: Constraint to place on the constant parameter.
            active_dims: The dimensions of the input with which to evaluate the
                kernel. This is moot for the constant kernel, but is included for
                compatibility with the Kernel API.
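
        Example (an illustrative sketch; `NormalPrior` and `Interval` are assumed
        to come from `gpytorch.priors` and `gpytorch.constraints`, as elsewhere
        in GPyTorch):
            >>> import torch
            >>> from gpytorch.constraints import Interval
            >>> from gpytorch.priors import NormalPrior
            >>> kernel = ConstantKernel(
            ...     batch_shape=torch.Size([2]),
            ...     constant_prior=NormalPrior(0.0, 1.0),
            ...     constant_constraint=Interval(0.0, 10.0),
            ... )
            >>> kernel.constant.shape
            torch.Size([2, 1])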
        """
        super().__init__(batch_shape=batch_shape, active_dims=active_dims)

        self.register_parameter(
            name="raw_constant",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )

        if constant_prior is not None:
            if not isinstance(constant_prior, Prior):
                raise TypeError("Expected gpytorch.priors.Prior but got " + type(constant_prior).__name__)
            self.register_prior(
                "constant_prior",
                constant_prior,
                lambda m: m.constant,
                lambda m, v: m._set_constant(v),
            )

        if constant_constraint is None:
            constant_constraint = Positive()
        self.register_constraint("raw_constant", constant_constraint)

    @property
    def constant(self) -> Tensor:
        return self.raw_constant_constraint.transform(self.raw_constant)

    @constant.setter
    def constant(self, value: Tensor) -> None:
        self._set_constant(value)

    def _set_constant(self, value: Tensor) -> None:
        # Reshape to the batch shape of the kernel and map through the inverse
        # constraint transform, since the stored parameter is unconstrained.
        value = value.view(*self.batch_shape, 1)
        self.initialize(raw_constant=self.raw_constant_constraint.inverse_transform(value))

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: bool = False,
        last_dim_is_batch: bool = False,
    ) -> Tensor:
        """Evaluates the constant kernel.

        Args:
            x1: First input tensor of shape (batch_shape x n1 x d).
            x2: Second input tensor of shape (batch_shape x n2 x d).
            diag: If True, returns the diagonal of the covariance matrix.
            last_dim_is_batch: If True, the last dimension of size `d` of the input
                tensors is treated as a batch dimension.

        Returns:
            A (batch_shape x n1 x n2)-dim tensor of constant covariance values if
            diag is False, or a (batch_shape x n1)-dim tensor if diag is True.
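
        Example (illustrative only; shapes assume a kernel with an empty batch
        shape):
            >>> kernel = ConstantKernel()
            >>> x1, x2 = torch.randn(5, 2), torch.randn(3, 2)
            >>> kernel.forward(x1, x2).shape
            torch.Size([5, 3])
            >>> kernel.forward(x1, x1, diag=True).shape
            torch.Size([5])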
        """
        if last_dim_is_batch:
            # Move the feature dimension of size d into the batch dimensions and
            # append a singleton feature dimension in its place.
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)

        dtype = torch.promote_types(x1.dtype, x2.dtype)
        batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
        shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
        constant = self.constant.to(dtype=dtype, device=x1.device)

        if not diag:
            constant = constant.unsqueeze(-1)

        if last_dim_is_batch:
            constant = constant.unsqueeze(-1)

        return constant.expand(shape)