#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Optional

import torch
from botorch.exceptions import UnsupportedError
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from gpytorch.kernels.matern_kernel import MaternKernel
from gpytorch.priors import Prior
from gpytorch.priors.torch_priors import GammaPrior


class LinearTruncatedFidelityKernel(Kernel):
    r"""
    Computes a covariance matrix based on the linear truncated kernel between
    inputs :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`:

    .. math::

        \begin{equation*}
        k_{\text{LinearTruncated}}(\mathbf{x_1}, \mathbf{x_2}) = k_0
        + c_1(\mathbf{x_1}, \mathbf{x_2}) k_1 + c_2(\mathbf{x_1}, \mathbf{x_2}) k_2
        + c_3(\mathbf{x_1}, \mathbf{x_2}) k_3
        \end{equation*}

    where

    * :math:`k_i \ (i=0,1,2,3)` are Matern kernels calculated between
      :math:`\mathbf{x_1}[:-2]` and :math:`\mathbf{x_2}[:-2]`; `k_0` uses
      :attr:`lengthscale_prior` while `k_1, k_2, k_3` share
      :attr:`lengthscale_2_prior`.
    * :math:`c_1 = (1-\mathbf{x_1}[-1])(1-\mathbf{x_2}[-1])
      (1+\mathbf{x_1}[-1]\mathbf{x_2}[-1])^p` is the kernel of the bias term,
      which can be decomposed into a deterministic part and a polynomial kernel.
    * :math:`c_3` is the same as :math:`c_1`, but calculated from the
      second-to-last entries of :math:`\mathbf{x_1}` and :math:`\mathbf{x_2}`.
    * :math:`c_2` is the interaction term, consisting of four deterministic
      factors and the polynomial kernel between :math:`\mathbf{x_1}[-2:]` and
      :math:`\mathbf{x_2}[-2:]`.
    * :math:`p` is the order of the polynomial kernel.

    .. note::

        We assume the last two dimensions of the input `x` are the fidelity
        parameters (only the last one if a single fidelity is active).

    Args:
        :attr:`nu` (float):
            The smoothness parameter of the Matern kernels: either 1/2, 3/2,
            or 5/2. Default: `2.5`
        :attr:`train_iteration_fidelity` (bool):
            Set this to `True` if the input includes an iteration fidelity
            parameter. Default: `True`
        :attr:`train_data_fidelity` (bool):
            Set this to `True` if the input includes a data fidelity
            parameter. Default: `True`
        :attr:`batch_shape` (torch.Size, optional):
            Set this if you want a separate lengthscale for each batch of
            input data. It should be `b` if :attr:`x1` is a `b x n x d`
            tensor. Default: `torch.Size([])`
        :attr:`active_dims` (tuple of ints, optional):
            Set this if you want to compute the covariance of only a few input
            dimensions. The ints correspond to the indices of the dimensions.
            Default: `None`
        :attr:`lengthscale_prior` (Prior, optional):
            Set this if you want to apply a prior to the lengthscale parameter
            of the Matern kernel `k_0`. Default: `Gamma(1.1, 1/20)`
        :attr:`lengthscale_constraint` (Constraint, optional):
            Set this if you want to apply a constraint to the lengthscale
            parameter of the Matern kernel `k_0`. Default: `Positive`
        :attr:`lengthscale_2_prior` (Prior, optional):
            Set this if you want to apply a prior to the lengthscale parameter
            of the Matern kernels `k_i (i > 0)`. Default: `Gamma(5, 1/20)`
        :attr:`lengthscale_2_constraint` (Constraint, optional):
            Set this if you want to apply a constraint to the lengthscale
            parameter of the Matern kernels `k_i (i > 0)`. Default: `Positive`
        :attr:`power_prior` (Prior, optional):
            Set this if you want to apply a prior to the power parameter of
            the polynomial kernel. Default: `None`
        :attr:`power_constraint` (Constraint, optional):
            Set this if you want to apply a constraint to the power parameter
            of the polynomial kernel. Default: `Positive`

    Attributes:
        :attr:`lengthscale` (Tensor):
            The lengthscale parameter; its shape depends on the
            :attr:`batch_shape` argument.

    Example:
        >>> x = torch.randn(10, 5)
        >>> # Non-batch: Simple option
        >>> covar_module = LinearTruncatedFidelityKernel()
        >>> covar = covar_module(x)  # Output: LazyVariable of size (10 x 10)
        >>>
        >>> batch_x = torch.randn(2, 10, 5)
        >>> # Batch: Simple option
        >>> covar_module = LinearTruncatedFidelityKernel(batch_shape=torch.Size([2]))
        >>> covar = covar_module(batch_x)  # Output: LazyVariable of size (2 x 10 x 10)
    """

    def __init__(
        self,
        nu: float = 2.5,
        train_iteration_fidelity: bool = True,
        train_data_fidelity: bool = True,
        lengthscale_prior: Optional[Prior] = None,
        power_prior: Optional[Prior] = None,
        power_constraint: Optional[Interval] = None,
        lengthscale_2_prior: Optional[Prior] = None,
        lengthscale_2_constraint: Optional[Interval] = None,
        **kwargs: Any,
    ):
        if not train_iteration_fidelity and not train_data_fidelity:
            raise UnsupportedError("You should have at least one fidelity parameter.")
        if nu not in {0.5, 1.5, 2.5}:
            raise ValueError("nu expected to be 0.5, 1.5, or 2.5")
        super().__init__(has_lengthscale=True, **kwargs)
        self.train_iteration_fidelity = train_iteration_fidelity
        self.train_data_fidelity = train_data_fidelity
        if power_constraint is None:
            power_constraint = Positive()

        if lengthscale_prior is None:
            self.lengthscale_prior = GammaPrior(1.1, 1 / 20)
        else:
            self.lengthscale_prior = lengthscale_prior

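        # A user-supplied lengthscale_2 prior is registered below so that it
        # participates in gpytorch's prior machinery; the default is kept as a
        # plain attribute. Either way, `self.lengthscale_2_prior` is passed to
        # the Matern kernels constructed in `forward`.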
        if lengthscale_2_prior is None:
            self.lengthscale_2_prior = GammaPrior(5, 1 / 20)
        else:
            self.register_prior(
                "lengthscale_2_prior",
                lengthscale_2_prior,
                lambda: self.lengthscale_2,
                lambda v: self._set_lengthscale_2(v),
            )

        if lengthscale_2_constraint is None:
            lengthscale_2_constraint = Positive()

        self.register_parameter(
            name="raw_power",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )
        self.register_parameter(
            name="raw_lengthscale_2",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )

        if power_prior is not None:
            self.register_prior(
                "power_prior",
                power_prior,
                lambda: self.power,
                lambda v: self._set_power(v),
            )
        self.nu = nu
        self.register_constraint("raw_lengthscale_2", lengthscale_2_constraint)
        self.register_constraint("raw_power", power_constraint)

    @property
    def power(self) -> torch.Tensor:
        return self.raw_power_constraint.transform(self.raw_power)

    @power.setter
    def power(self, value: torch.Tensor) -> None:
        self._set_power(value)

    def _set_power(self, value: torch.Tensor) -> None:
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_power)
        self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value))

    @property
    def lengthscale_2(self) -> torch.Tensor:
        return self.raw_lengthscale_2_constraint.transform(self.raw_lengthscale_2)

    @lengthscale_2.setter
    def lengthscale_2(self, value: torch.Tensor) -> None:
        self._set_lengthscale_2(value)

    def _set_lengthscale_2(self, value: torch.Tensor) -> None:
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_lengthscale_2)
        self.initialize(
            raw_lengthscale_2=self.raw_lengthscale_2_constraint.inverse_transform(value)
        )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor, **params) -> torch.Tensor:
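        # Number of active fidelity parameters (1 or 2); they occupy the last
        # `m` columns of `x1` and `x2`.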
        m = self.train_iteration_fidelity + self.train_data_fidelity
        power = self.power.view(*self.batch_shape, 1, 1)
        active_dimsM = list(range(x1.size(-1) - m))
        covar_module_1 = MaternKernel(
            nu=self.nu,
            batch_shape=self.batch_shape,
            lengthscale_prior=self.lengthscale_prior,
            active_dims=active_dimsM,
            ard_num_dims=x1.shape[-1] - m,
        )
        covar_module_2 = MaternKernel(
            nu=self.nu,
            batch_shape=self.batch_shape,
            lengthscale_prior=self.lengthscale_2_prior,
            active_dims=active_dimsM,
            ard_num_dims=x1.shape[-1] - m,
        )
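        # NOTE: these Matern kernels are rebuilt on every call to `forward`, so
        # their lengthscales are re-initialized each time rather than being
        # trained parameters of this module.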
        covar_0 = covar_module_1(x1, x2)
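        # Last fidelity column of each input, shaped `... x n1 x 1` and
        # `... x 1 x n2` so that the bias factors broadcast to the shape of
        # the `n1 x n2` covariance matrix.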
        x11_ = x1[..., -1].unsqueeze(-1)
        x21t_ = x2[..., -1].unsqueeze(-1).transpose(-1, -2)
        covar_1 = covar_module_2(x1, x2)
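        # With both fidelity parameters active, add a bias kernel for each
        # fidelity column plus their interaction term; with a single active
        # fidelity, only one bias term is needed.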
        if self.train_iteration_fidelity and self.train_data_fidelity:
            covar_2 = covar_module_2(x1, x2)
            covar_3 = covar_module_2(x1, x2)
            x12_ = x1[..., -2].unsqueeze(-1)
            x22t_ = x2[..., -2].unsqueeze(-1).transpose(-1, -2)
            res = (
                covar_0
                + (1 - x12_) * (1 - x22t_) * (1 + x12_ * x22t_).pow(power) * covar_1
                + (1 - x12_)
                * (1 - x22t_)
                * (1 - x11_)
                * (1 - x21t_)
                * (1 + torch.matmul(x1[..., -2:], x2[..., -2:].transpose(-1, -2))).pow(
                    power
                )
                * covar_2
                + (1 - x11_) * (1 - x21t_) * (1 + x11_ * x21t_).pow(power) * covar_3
            )
        else:
            res = (
                covar_0
                + (1 - x11_) * (1 - x21t_) * (1 + x11_ * x21t_).pow(power) * covar_1
            )
        return res
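

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal smoke test: it assumes botorch/gpytorch are installed, that the
# last two columns of the (hypothetical) random inputs act as the fidelity
# parameters, and that the installed gpytorch version exposes `evaluate()` on
# lazy tensors.
if __name__ == "__main__":
    torch.manual_seed(0)
    x1 = torch.rand(10, 5)  # 3 design dimensions + 2 fidelity columns in [0, 1]
    x2 = torch.rand(8, 5)
    kernel = LinearTruncatedFidelityKernel(nu=2.5)
    covar = kernel(x1, x2)  # lazily evaluated cross-covariance
    print(covar.evaluate().shape)  # expected: torch.Size([10, 8])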