#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
.. [rahimi2007random]
    A. Rahimi and B. Recht. Random features for large-scale kernel machines.
    Advances in Neural Information Processing Systems 20 (2007).

.. [sutherland2015error]
    D. J. Sutherland and J. Schneider. On the error of random Fourier features.
    arXiv preprint arXiv:1506.02785 (2015).
"""

from __future__ import annotations

from typing import Any, Callable

import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.sampling.pathwise.features.maps import KernelFeatureMap
from botorch.sampling.pathwise.utils import (
    ChainedTransform,
    FeatureSelector,
    InverseLengthscaleTransform,
    OutputscaleTransform,
    SineCosineTransform,
)
from botorch.utils.dispatcher import Dispatcher
from botorch.utils.sampling import draw_sobol_normal_samples
from gpytorch import kernels
from gpytorch.kernels.kernel import Kernel
from torch import Size, Tensor
from torch.distributions import Gamma

TKernelFeatureMapGenerator = Callable[[Kernel, int, int], KernelFeatureMap]
GenKernelFeatures = Dispatcher("gen_kernel_features")
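# `gen_kernel_features` (below) dispatches on the kernel type via `GenKernelFeatures`;
# per-kernel generators are registered further down with `@GenKernelFeatures.register`.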


def gen_kernel_features(
    kernel: kernels.Kernel,
    num_inputs: int,
    num_outputs: int,
    **kwargs: Any,
) -> KernelFeatureMap:
    r"""Generates a feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^{n}` such that
    :math:`k(x, x') \approx \phi(x)^{T} \phi(x')`. For stationary kernels :math:`k`,
    this defaults to the method of random Fourier features. For more details, see
    [rahimi2007random]_ and [sutherland2015error]_.

    Args:
        kernel: The kernel :math:`k` to be represented via a finite-dim basis.
        num_inputs: The number of input features.
        num_outputs: The number of kernel features.
        **kwargs: Keyword arguments passed through to the registered generator.
    """
    return GenKernelFeatures(
        kernel,
        num_inputs=num_inputs,
        num_outputs=num_outputs,
        **kwargs,
    )
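# Example (illustrative sketch, not part of the module API): approximate a
# Matern-5/2 kernel on 3-dimensional inputs with 1024 Fourier features. The
# feature map's inner products should then roughly match the kernel matrix,
# up to Monte Carlo error:
#
#   kernel = kernels.MaternKernel(nu=2.5, ard_num_dims=3)
#   feature_map = gen_kernel_features(kernel, num_inputs=3, num_outputs=1024)
#   X = torch.rand(16, 3)
#   K_approx = feature_map(X) @ feature_map(X).transpose(-2, -1)
#   K_exact = kernel(X).to_dense()  # K_approx ≈ K_exact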


def _gen_fourier_features(
    kernel: kernels.Kernel,
    weight_generator: Callable[[Size], Tensor],
    num_inputs: int,
    num_outputs: int,
) -> KernelFeatureMap:
    r"""Generate a feature map :math:`\phi: \mathcal{X} \to \mathbb{R}^{2l}` that
    approximates a stationary kernel so that
    :math:`k(x, x') \approx \phi(x)^\top \phi(x')`.

    Following [sutherland2015error]_, we represent complex exponentials by pairs of
    basis functions :math:`\phi_{i}(x) = \sin(x^\top w_{i})` and
    :math:`\phi_{i + l}(x) = \cos(x^\top w_{i})`.

    Args:
        kernel: A stationary kernel :math:`k(x, x') = k(x - x')`.
        weight_generator: A callable used to generate weight vectors :math:`w`.
        num_inputs: The number of input features.
        num_outputs: The number of Fourier features.
    """
    if num_outputs % 2:
        raise UnsupportedError(
            f"Expected an even number of output features, but received {num_outputs=}."
        )

    input_transform = InverseLengthscaleTransform(kernel)
    if kernel.active_dims is not None:
        num_inputs = len(kernel.active_dims)
        input_transform = ChainedTransform(
            input_transform, FeatureSelector(indices=kernel.active_dims)
        )

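    # Draw `num_outputs // 2` frequency vectors per batch element; each vector `w`
    # yields one sin/cos pair of basis functions below.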
    weight = weight_generator(
        Size([kernel.batch_shape.numel() * num_outputs // 2, num_inputs])
    ).reshape(*kernel.batch_shape, num_outputs // 2, num_inputs)

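    # Scale features by sqrt(2 / num_outputs) so that the average over the sin/cos
    # pairs is an unbiased Monte Carlo estimate of the kernel value.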
    output_transform = SineCosineTransform(
        torch.tensor((2 / num_outputs) ** 0.5, device=kernel.device, dtype=kernel.dtype)
    )
    return KernelFeatureMap(
        kernel=kernel,
        weight=weight,
        input_transform=input_transform,
        output_transform=output_transform,
    )


@GenKernelFeatures.register(kernels.RBFKernel)
def _gen_kernel_features_rbf(
    kernel: kernels.RBFKernel,
    *,
    num_inputs: int,
    num_outputs: int,
) -> KernelFeatureMap:
    def _weight_generator(shape: Size) -> Tensor:
        try:
            n, d = shape
        except ValueError:
            raise UnsupportedError(
                f"Expected `shape` to be 2-dimensional, but {len(shape)=}."
            )

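        # The spectral density of an RBF kernel with unit lengthscale is a standard
        # Gaussian, so frequencies are drawn as quasi-random standard normal samples
        # (lengthscales are handled by the inverse-lengthscale input transform).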
        return draw_sobol_normal_samples(
            n=n,
            d=d,
            device=kernel.lengthscale.device,
            dtype=kernel.lengthscale.dtype,
        )

    return _gen_fourier_features(
        kernel=kernel,
        weight_generator=_weight_generator,
        num_inputs=num_inputs,
        num_outputs=num_outputs,
    )


@GenKernelFeatures.register(kernels.MaternKernel)
def _gen_kernel_features_matern(
    kernel: kernels.MaternKernel,
    *,
    num_inputs: int,
    num_outputs: int,
) -> KernelFeatureMap:
    def _weight_generator(shape: Size) -> Tensor:
        try:
            n, d = shape
        except ValueError:
            raise UnsupportedError(
                f"Expected `shape` to be 2-dimensional, but {len(shape)=}."
            )

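        # The spectral density of a Matern-nu kernel is a multivariate Student-t with
        # 2 * nu degrees of freedom; dividing Gaussian samples by the square root of
        # Gamma(nu, nu) draws yields exactly such t-distributed frequencies.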
        dtype = kernel.lengthscale.dtype
        device = kernel.lengthscale.device
        nu = torch.tensor(kernel.nu, device=device, dtype=dtype)
        normals = draw_sobol_normal_samples(n=n, d=d, device=device, dtype=dtype)
        return Gamma(nu, nu).rsample((n, 1)).rsqrt() * normals

    return _gen_fourier_features(
        kernel=kernel,
        weight_generator=_weight_generator,
        num_inputs=num_inputs,
        num_outputs=num_outputs,
    )


@GenKernelFeatures.register(kernels.ScaleKernel)
def _gen_kernel_features_scale(
    kernel: kernels.ScaleKernel,
    *,
    num_inputs: int,
    num_outputs: int,
) -> KernelFeatureMap:
    active_dims = kernel.active_dims
    feature_map = gen_kernel_features(
        kernel.base_kernel,
        num_inputs=num_inputs if active_dims is None else len(active_dims),
        num_outputs=num_outputs,
    )

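    # If the ScaleKernel restricts `active_dims` beyond what the base kernel already
    # selects, prepend a feature selector so the base map only sees those columns.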
    if active_dims is not None and active_dims is not kernel.base_kernel.active_dims:
        feature_map.input_transform = ChainedTransform(
            feature_map.input_transform, FeatureSelector(indices=active_dims)
        )

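    # Apply the kernel's outputscale to the features so that their inner products
    # recover `outputscale * k_base(x, x')`.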
    feature_map.output_transform = ChainedTransform(
        OutputscaleTransform(kernel), feature_map.output_transform
    )
    return feature_map