@@ -8,7 +8,7 @@
 Gaussian Process Regression models based on GPyTorch models.
 """
 
-from typing import Any, Optional, Union
+from typing import Any, List, Optional, Union
 
 import torch
 from gpytorch.constraints.constraints import GreaterThan
@@ -117,8 +117,15 @@ def __init__(
                 batch_shape=self._aug_batch_shape,
                 outputscale_prior=GammaPrior(2.0, 0.15),
             )
+            self._subset_batch_dict = {
+                "likelihood.noise_covar.raw_noise": -2,
+                "mean_module.constant": -2,
+                "covar_module.raw_outputscale": -1,
+                "covar_module.base_kernel.raw_lengthscale": -3,
+            }
         else:
             self.covar_module = covar_module
+            # TODO: Allow subsetting of other covar modules
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
         self.to(train_X)
@@ -192,6 +199,11 @@ def __init__(
         )
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
+        self._subset_batch_dict = {
+            "mean_module.constant": -2,
+            "covar_module.raw_outputscale": -1,
+            "covar_module.base_kernel.raw_lengthscale": -3,
+        }
         self.to(train_X)
 
     def fantasize(
@@ -242,6 +254,21 @@ def forward(self, x: Tensor) -> MultivariateNormal:
         covar_x = self.covar_module(x)
         return MultivariateNormal(mean_x, covar_x)
 
+    def subset_output(self, idcs: List[int]) -> "BatchedMultiOutputGPyTorchModel":
+        r"""Subset the model along the output dimension.
+
+        Args:
+            idcs: The output indices to subset the model to.
+
+        Returns:
+            The current model, subset to the specified output indices.
+        """
+        new_model = super().subset_output(idcs=idcs)
+        full_noise = new_model.likelihood.noise_covar.noise
+        new_noise = full_noise[..., idcs if len(idcs) > 1 else idcs[0], :]
+        new_model.likelihood.noise_covar.noise = new_noise
+        return new_model
+
 
 class HeteroskedasticSingleTaskGP(SingleTaskGP):
     r"""A single-task exact GP model using a heteroskedastic noise model.
@@ -311,3 +338,6 @@ def condition_on_observations(
         self, X: Tensor, Y: Tensor, **kwargs: Any
     ) -> "HeteroskedasticSingleTaskGP":
         raise NotImplementedError
+
+    def subset_output(self, idcs: List[int]) -> "HeteroskedasticSingleTaskGP":
+        raise NotImplementedError
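For context, a minimal usage sketch of the new `subset_output` (not part of this commit): it assumes the standard `FixedNoiseGP` constructor from `botorch.models`, and the toy data, shapes, and variable names are illustrative only.

```python
import torch

from botorch.models import FixedNoiseGP

# Toy data (illustrative): 10 training points, 2-d inputs, 3 outputs with
# known observation noise; the multi-output model is handled as a batched GP.
train_X = torch.rand(10, 2)
train_Y = torch.rand(10, 3)
train_Yvar = torch.full_like(train_Y, 0.01)
model = FixedNoiseGP(train_X, train_Y, train_Yvar)

# Keep only outputs 0 and 2. The overridden subset_output also subsets the
# fixed noise levels, while _subset_batch_dict tells the base class which
# batch dimension of each listed hyperparameter corresponds to the outputs.
sub_model = model.subset_output([0, 2])

posterior = sub_model.posterior(torch.rand(5, 2))
print(posterior.mean.shape)  # expected: torch.Size([5, 2])
```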