
Commit 7258017

Carl Hvarfner authored and facebook-github-bot committed
Non-uniform model weights in EnsembleModel & EnsemblePosterior (#2993)
Summary: Pull Request resolved: #2993

Adds support for non-uniform model weights in EnsembleModel & EnsemblePosterior.

Why bother?
- Setting non-uniform weights in a `MatheronPathModel` (e.g., [0, 0, 0, 1] over the different models in an ensemble such as a `FullyBayesianSingleTaskGP`) allows drawing function samples from fully Bayesian models.
- Also implements modified sampling, since the rsample() method of the ensemble posterior would otherwise not be valid with batch shapes. Sampling occurs along the ensemble dimension only. The intended use is in benchmarking.

What this diff does not do: create a batchable rsample for EnsemblePosteriors. The current implementation does not work as intended: it does not sample exclusively over the ensemble dimension and cannot handle different ensemble weights for different batches. Instead, it samples one model in the ensemble and applies it across batches. Attempting to implement this proved very cumbersome and broke a number of tests across the stack.

Reviewed By: Balandat

Differential Revision: D80728012

fbshipit-source-id: 6a2316862d98f0275a199b9b1442040b675337ef
1 parent 290f43b commit 7258017
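
For context, a minimal sketch of the one-hot-weights use case described in the summary, written against the EnsemblePosterior API added in this diff (shapes and values here are illustrative, not from the commit):

    import torch
    from botorch.posteriors.ensemble import EnsemblePosterior

    # An ensemble of s=4 members evaluated at q=5 points with m=1 output each,
    # e.g. four hyperparameter draws of a fully Bayesian model.
    values = torch.randn(4, 5, 1)

    # One-hot weights collapse the posterior onto a single ensemble member,
    # so every drawn sample is a function sample from that member.
    posterior = EnsemblePosterior(
        values=values, weights=torch.tensor([0.0, 0.0, 0.0, 1.0])
    )
    samples = posterior.rsample(torch.Size([16]))  # shape: 16 x 5 x 1
    assert torch.allclose(posterior.mean, values[3])  # mean is member 3 exactly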

File tree: 6 files changed, +325 −47 lines


botorch/models/ensemble.py

Lines changed: 14 additions & 1 deletion
@@ -24,6 +24,19 @@
 class EnsembleModel(Model, ABC):
     """Abstract base class for ensemble models."""

+    def __init__(self, weights: Tensor | None = None):
+        """Initialize the ensemble model.
+
+        Args:
+            weights: Optional weights for the ensemble members.
+                If None, the model weights will default to uniform in the
+                corresponding mixture posterior.
+        """
+        super().__init__()
+        # buffer `weights` is generally a name occupied by another module,
+        # so we have to be more specific here
+        self.ensemble_weights = weights
+
     @abstractmethod
     def forward(self, X: Tensor) -> Tensor:
         r"""Compute the (ensemble) model output at X.

@@ -82,7 +95,7 @@ def posterior(
             values, _ = self.outcome_transform.untransform(values, X=X)
         if output_indices is not None:
             values = values[..., output_indices]
-        posterior = EnsemblePosterior(values=values)
+        posterior = EnsemblePosterior(values=values, weights=self.ensemble_weights)
         if posterior_transform is not None:
             return posterior_transform(posterior)
         else:
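
To make the new constructor concrete, here is a hedged sketch of how a subclass might opt in to non-uniform weights (the class and its internals are hypothetical, not part of this commit):

    import torch
    from torch import Tensor, nn
    from botorch.models.ensemble import EnsembleModel

    class WeightedDeepEnsemble(EnsembleModel):
        """Hypothetical deep ensemble whose members carry unequal weights."""

        def __init__(self, networks: list[nn.Module], weights: Tensor | None = None):
            super().__init__(weights=weights)  # stored as self.ensemble_weights
            self.networks = nn.ModuleList(networks)
            self._num_outputs = 1  # assuming single-output member networks

        def forward(self, X: Tensor) -> Tensor:
            # Stack member predictions along the ensemble dim: `s x q x m`
            return torch.stack([net(X) for net in self.networks], dim=-3)

posterior(X) then forwards self.ensemble_weights into the EnsemblePosterior, as in the second hunk above.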

botorch/posteriors/ensemble.py

Lines changed: 155 additions & 20 deletions
@@ -13,34 +13,91 @@
 import torch
 from botorch.posteriors.posterior import Posterior
 from torch import Tensor
+from torch.distributions.multinomial import Multinomial


 class EnsemblePosterior(Posterior):
     r"""Ensemble posterior, that should be used for ensemble models that compute
     eagerly a finite number of samples per X value as for example a deep ensemble
     or a random forest."""

-    def __init__(self, values: Tensor) -> None:
+    def __init__(self, values: Tensor, weights: Tensor | None = None) -> None:
         r"""
         Args:
             values: Values of the samples produced by this posterior as
                 a `(b) x s x q x m` tensor where `m` is the output size of the
                 model and `s` is the ensemble size.
+            weights: Optional weights for the ensemble members as a tensor of shape
+                `(s,)`. If None, uses uniform weights.
         """
         if values.ndim < 3:
             raise ValueError("Values has to be at least three-dimensional.")
         self.values = values
+        self._weights = weights.to(values) if weights is not None else None
+        # Pre-compute normalized weights and mixture properties for efficiency
+        self._mixture_dims = list(range(self.values.ndim - 2))
+        self._normalized_weights = self._compute_normalized_weights()
+        self._normalized_mixture_weights = self._compute_normalized_mixture_weights()

     @property
     def ensemble_size(self) -> int:
         r"""The size of the ensemble"""
         return self.values.shape[-3]

+    @property
+    def mixture_size(self) -> int:
+        r"""The total number of elements in the mixture dimensions"""
+        return self.values.shape[:-2].numel()
+
+    def _compute_normalized_weights(self) -> Tensor:
+        r"""Compute and cache normalized weights."""
+        if self._weights is not None:
+            return self._weights / self._weights.sum(dim=-1, keepdim=True)
+        else:
+            return (
+                torch.ones(
+                    self.ensemble_size,
+                    dtype=self.dtype,
+                    device=self.device,
+                )
+                / self.ensemble_size
+            )
+
+    def _compute_normalized_mixture_weights(self) -> Tensor:
+        r"""Compute and cache normalized mixture weights."""
+        if self._weights is not None:
+            unnorm_weights = self._weights.expand(self.values.shape[:-2])
+            return unnorm_weights / unnorm_weights.sum(
+                dim=self._mixture_dims, keepdim=True
+            )
+        else:
+            return (
+                torch.ones(
+                    self.values.shape[:-2],
+                    dtype=self.dtype,
+                    device=self.device,
+                )
+                / self.mixture_size
+            )
+
     @property
     def weights(self) -> Tensor:
         r"""The weights of the individual models in the ensemble.
-        Equally weighted by default."""
-        return torch.ones(self.ensemble_size) / self.ensemble_size
+        Uniformly weighted by default."""
+        return self._normalized_weights
+
+    @property
+    def mixture_weights(self) -> Tensor:
+        r"""The weights of the individual models in the ensemble.
+        Uniformly weighted by default, and normalized over the ensemble and
+        batch dimensions of the model."""
+        return self._normalized_mixture_weights
+
+    @property
+    def mixture_dims(self) -> list[int]:
+        r"""The mixture dimensions of the posterior. For ensemble posteriors,
+        this includes all dimensions except the last two (query points and outputs)."""
+        return self._mixture_dims

     @property
     def device(self) -> torch.device:
@@ -55,17 +112,60 @@ def dtype(self) -> torch.dtype:
     @property
     def mean(self) -> Tensor:
         r"""The mean of the posterior as a `(b) x n x m`-dim Tensor."""
-        return self.values.mean(dim=-3)
+        # Weighted average across the ensemble dimension
+        return (self.values * self.weights[..., None, None]).sum(dim=-3)

     @property
     def variance(self) -> Tensor:
         r"""The variance of the posterior as a `(b) x n x m`-dim Tensor.

-        Computed as the sample variance across the ensemble outputs.
+        Computed as the weighted sample variance across the ensemble outputs.
+
+        This treats weights as probability weights (normalized to sum to 1) and
+        computes the unbiased weighted sample variance using the formula
+        Var = Σ(w_i * (x_i - μ)²) / (1 - Σw_i²),
+        where the sum over w_i² is taken over the ensemble dimension only.
+        Source: https://en.wikipedia.org/wiki/Weighted_arithmetic_mean under
+        "Reliability weights".
         """
         if self.ensemble_size == 1:
             return torch.zeros_like(self.values.squeeze(-3))
-        return self.values.var(dim=-3)
+
+        # Add dimensions for query points and outputs to enable broadcasting
+        weights = self.weights[..., None, None]
+        squared_deviations = (self.values - self.mean.unsqueeze(-3)) ** 2
+        return (weights * squared_deviations).sum(dim=-3) / (1 - (weights**2).sum())
+
+    @property
+    def mixture_mean(self) -> Tensor:
+        r"""The mixture mean of the posterior as a `(b) x n x m`-dim Tensor.
+
+        Computed as the weighted average across the ensemble outputs.
+        """
+        return (self.values * self.mixture_weights[..., None, None]).sum(
+            dim=self.mixture_dims
+        )
+
+    @property
+    def mixture_variance(self) -> Tensor:
+        r"""The mixture variance of the posterior as a `(b) x n x m`-dim Tensor.
+
+        Computed as the weighted sample variance across the ensemble outputs.
+
+        This treats weights as probability weights (normalized to sum to 1) and
+        computes the unbiased weighted sample variance using the formula
+        Var = Σ(w_i * (x_i - μ)²) / (1 - Σw_i²), where w_i is normalized over the
+        entire mixture and the sum over w_i² is taken over all mixture dimensions.
+        Source: https://en.wikipedia.org/wiki/Weighted_arithmetic_mean under
+        "Reliability weights".
+        """
+        # Add dimensions for query points and outputs to enable broadcasting
+        weights = self.mixture_weights[..., None, None]
+        squared_deviations = (self.values - self.mixture_mean.unsqueeze(-3)) ** 2
+        return (weights * squared_deviations).sum(dim=self.mixture_dims) / (
+            1 - (weights**2).sum()
+        )

     def _extended_shape(
         self,
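
As a quick sanity check of the reliability-weights formula in the docstrings above (a standalone illustration, not part of the diff):

    import torch

    x = torch.tensor([1.0, 2.0, 4.0])   # three ensemble "outputs"
    w = torch.tensor([0.5, 0.3, 0.2])   # normalized probability weights
    mu = (w * x).sum()                  # weighted mean: 1.9
    # Var = Σ(w_i * (x_i - μ)²) / (1 - Σw_i²) = 1.29 / 0.62 ≈ 2.081
    var = (w * (x - mu) ** 2).sum() / (1 - (w**2).sum())
    # With uniform weights this reduces to the usual unbiased sample variance:
    assert torch.isclose(
        (torch.full((3,), 1 / 3) * (x - x.mean()) ** 2).sum()
        / (1 - 3 * (1 / 3) ** 2),
        x.var(),
    )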
@@ -76,6 +176,10 @@ def _extended_shape(
         """
         return sample_shape + self.values.shape[:-3] + self.values.shape[-2:]

+    @property
+    def batch_shape(self) -> torch.Size:
+        return self.values.shape[:-3]
+
     def rsample(
         self,
         sample_shape: torch.Size | None = None,
@@ -94,17 +198,26 @@
             Samples from the posterior, a tensor of shape
             `self._extended_shape(sample_shape=sample_shape)`.
         """
-        if sample_shape is None:
+        if sample_shape is None or len(sample_shape) == 0:
             sample_shape = torch.Size([1])
-        # get indices as base_samples
+
+        # NOTE: This occasionally happens in hypervolume evals when there
+        # are no points which improve over the reference point. In this case, we
+        # create a posterior for all the points which improve over the reference,
+        # which is an empty set.
+        if self.values.numel() == 0:
+            return torch.empty(
+                *self._extended_shape(sample_shape=sample_shape),
+                device=self.device,
+                dtype=self.dtype,
+            )
+
         base_samples = (
-            torch.multinomial(
-                self.weights,
-                num_samples=sample_shape.numel(),
-                replacement=True,
+            Multinomial(
+                probs=self.mixture_weights,
             )
-            .reshape(sample_shape)
-            .to(device=self.device)
+            .sample(sample_shape=sample_shape)
+            .argmax(dim=-1)
         )
         return self.rsample_from_base_samples(
             sample_shape=sample_shape, base_samples=base_samples
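
The switch from torch.multinomial to Multinomial(...).sample(...).argmax(-1) is what lets the index draw broadcast over batch dimensions of the weights; torch.multinomial only accepts 1-d or 2-d probability inputs. A standalone illustration (the shapes are made up):

    import torch
    from torch.distributions.multinomial import Multinomial

    weights = torch.rand(2, 4)  # batch_shape=(2,), ensemble size s=4
    # Multinomial normalizes `probs` along the last dim; total_count defaults
    # to 1, so each draw is a one-hot vector over the s ensemble members.
    one_hot = Multinomial(probs=weights).sample(torch.Size([8]))  # 8 x 2 x 4
    indices = one_hot.argmax(dim=-1)  # 8 x 2: one member index per batch element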
@@ -132,9 +245,31 @@ def rsample_from_base_samples(
             Samples from the posterior, a tensor of shape
             `self._extended_shape(sample_shape=sample_shape)`.
         """
-        if base_samples.shape != sample_shape:
-            raise ValueError("Base samples do not match sample shape.")
-        # move sample axis to front
-        values = self.values.movedim(-3, 0)
-        # sample from the first dimension of values
-        return values[base_samples, ...]
+        # Check that the leading dimensions of base_samples match sample_shape
+        if base_samples.shape != sample_shape + self.batch_shape:
+            raise ValueError(
+                f"sample_shape={sample_shape + self.batch_shape} does not match "
+                f"the leading dimensions of base_samples.shape={base_samples.shape}."
+            )
+
+        if self.batch_shape:
+            # `values` is always four-dimensional after this reshape,
+            # even if there is more than one batch dimension
+            values = self.values.reshape(
+                (self.batch_shape.numel(),) + self.values.shape[-3:]
+            )
+
+            # Collapse the base samples to enable index-selecting along the
+            # ensemble dim (dim -3)
+            batch_numel = self.batch_shape.numel()
+            collapsed_base_samples = base_samples.reshape(sample_shape + (batch_numel,))
+
+            # The first index is 0, 1, ..., batch_numel - 1, flattening the
+            # batch dimensions; the second index extracts the sampled ensemble
+            # member for each element of the batch
+            return values[torch.arange(batch_numel), collapsed_base_samples].reshape(
+                self._extended_shape(sample_shape=sample_shape)
+            )
+        return self.values[base_samples]
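
The batched gather at the heart of the new branch can be illustrated in isolation (a toy with assumed shapes, not BoTorch code): a row index 0..B-1 paired with per-batch ensemble choices selects one member per batch element.

    import torch

    B, S, Q, M = 3, 4, 5, 2          # flattened batch, ensemble, query, outputs
    values = torch.randn(B, S, Q, M)
    idx = torch.randint(S, (8, B))   # sample_shape=(8,): one choice per batch
    # Advanced indexing broadcasts arange(B) (shape B) with idx (shape 8 x B),
    # yielding out[k, b] = values[b, idx[k, b]] of shape 8 x B x Q x M.
    out = values[torch.arange(B), idx]
    assert out.shape == (8, B, Q, M)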

botorch/sampling/index_sampler.py

Lines changed: 12 additions & 6 deletions
@@ -15,6 +15,7 @@
 from botorch.posteriors.ensemble import EnsemblePosterior
 from botorch.sampling.base import MCSampler
 from torch import Tensor
+from torch.distributions.multinomial import Multinomial


 class IndexSampler(MCSampler):

@@ -44,14 +45,19 @@ def _construct_base_samples(self, posterior: EnsemblePosterior) -> None:
             posterior: The ensemble posterior to construct the base samples
                 for.
         """
-        if self.base_samples is None or self.base_samples.shape != self.sample_shape:
+        if (
+            self.base_samples is None
+            or self.base_samples.shape != self.sample_shape + posterior.batch_shape
+        ):
             with torch.random.fork_rng():
                 torch.manual_seed(self.seed)
-                base_samples = torch.multinomial(
-                    posterior.weights,
-                    num_samples=self.sample_shape.numel(),
-                    replacement=True,
-                ).reshape(self.sample_shape)
+                base_samples = (
+                    Multinomial(
+                        probs=posterior.mixture_weights,
+                    )
+                    .sample(sample_shape=self.sample_shape)
+                    .argmax(dim=-1)
+                )
             self.register_buffer("base_samples", base_samples)
         if self.base_samples.device != posterior.device:
             self.to(device=posterior.device)  # pragma: nocover
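
End to end, the updated IndexSampler draws one ensemble index per batch element and reuses the draw across calls. A hedged usage sketch against the APIs in this diff (shapes and weights are made up):

    import torch
    from botorch.posteriors.ensemble import EnsemblePosterior
    from botorch.sampling.index_sampler import IndexSampler

    values = torch.randn(2, 4, 5, 1)  # batch b=2, ensemble s=4, q=5, m=1
    posterior = EnsemblePosterior(
        values=values, weights=torch.tensor([0.1, 0.2, 0.3, 0.4])
    )
    sampler = IndexSampler(sample_shape=torch.Size([32]), seed=0)
    samples = sampler(posterior)  # 32 x 2 x 5 x 1; indices fixed by the seed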

test/models/test_ensemble.py

Lines changed: 18 additions & 2 deletions
@@ -12,9 +12,9 @@
 class DummyEnsembleModel(EnsembleModel):
     r"""A dummy ensemble model."""

-    def __init__(self):
+    def __init__(self, weights=None):
         r"""Init model."""
-        super().__init__()
+        super().__init__(weights=weights)
         self._num_outputs = 2
         self.a = torch.rand(4, 3, 2)

@@ -35,3 +35,19 @@ def test_DummyEnsembleModel(self):
         X = torch.randn(*shape)
         p = e.posterior(X)
         self.assertEqual(p.ensemble_size, 4)
+
+    def test_EnsembleModel_weights(self):
+        """Test that weights are properly passed from EnsembleModel to
+        EnsemblePosterior."""
+        custom_weights = torch.tensor([0.4, 0.3, 0.2, 0.1])
+        e = DummyEnsembleModel(weights=custom_weights)
+
+        # Test that weights are correctly passed through
+        X = torch.randn(5, 3)
+        p = e.posterior(X)
+        self.assertAllClose(p.weights, custom_weights)
+
+        # Test with batch dimensions - weights should remain 1-dimensional
+        X_batch = torch.randn(2, 5, 3)  # batch_shape = (2,)
+        p_batch = e.posterior(X_batch)
+        self.assertAllClose(p_batch.weights, custom_weights)
