import torch
from botorch.posteriors.posterior import Posterior
from torch import Tensor
+ from torch.distributions.multinomial import Multinomial


class EnsemblePosterior(Posterior):
    r"""Ensemble posterior, which should be used for ensemble models that
    eagerly compute a finite number of samples per X value, for example a
    deep ensemble or a random forest."""

-     def __init__(self, values: Tensor) -> None:
+     def __init__(self, values: Tensor, weights: Tensor | None = None) -> None:
        r"""
        Args:
            values: Values of the samples produced by this posterior as
                a `(b) x s x q x m` tensor where `m` is the output size of the
                model and `s` is the ensemble size.
+             weights: Optional weights for the ensemble members as a tensor of shape
+                 `(s,)`. If None, uses uniform weights.
        """
        if values.ndim < 3:
            raise ValueError("Values has to be at least three-dimensional.")
        self.values = values
+         self._weights = weights.to(values) if weights is not None else None
+         # Pre-compute normalized weights and mixture properties for efficiency
+         self._mixture_dims = list(range(self.values.ndim - 2))
+         self._normalized_weights = self._compute_normalized_weights()
+         self._normalized_mixture_weights = self._compute_normalized_mixture_weights()

    @property
    def ensemble_size(self) -> int:
        r"""The size of the ensemble."""
        return self.values.shape[-3]

+     @property
+     def mixture_size(self) -> int:
+         r"""The total number of elements in the mixture dimensions."""
+         return self.values.shape[:-2].numel()
+
+     def _compute_normalized_weights(self) -> Tensor:
+         r"""Compute and cache normalized weights."""
+         if self._weights is not None:
+             return self._weights / self._weights.sum(dim=-1, keepdim=True)
+         else:
+             return (
+                 torch.ones(
+                     self.ensemble_size,
+                     dtype=self.dtype,
+                     device=self.device,
+                 )
+                 / self.ensemble_size
+             )
+
+     def _compute_normalized_mixture_weights(self) -> Tensor:
+         r"""Compute and cache normalized mixture weights."""
+         if self._weights is not None:
+             unnorm_weights = self._weights.expand(self.values.shape[:-2])
+             return unnorm_weights / unnorm_weights.sum(
+                 dim=self._mixture_dims, keepdim=True
+             )
+         else:
+             return (
+                 torch.ones(
+                     self.values.shape[:-2],
+                     dtype=self.dtype,
+                     device=self.device,
+                 )
+                 / self.mixture_size
+             )
+

    @property
    def weights(self) -> Tensor:
        r"""The weights of the individual models in the ensemble.
-         Equally weighted by default."""
-         return torch.ones(self.ensemble_size) / self.ensemble_size
+         Uniformly weighted by default."""
+         return self._normalized_weights
+
+     @property
+     def mixture_weights(self) -> Tensor:
+         r"""The weights of the individual models in the ensemble,
+         uniformly weighted by default and normalized over the ensemble and
+         batch dimensions of the model."""
+         return self._normalized_mixture_weights
+
+     @property
+     def mixture_dims(self) -> list[int]:
+         r"""The mixture dimensions of the posterior. For ensemble posteriors,
+         this includes all dimensions except the last two (query points and
+         outputs)."""
+         return self._mixture_dims

    @property
    def device(self) -> torch.device:
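For intuition on the two normalizations defined above: `weights` normalizes over the ensemble dimension only, while `mixture_weights` expands the weights over the batch dimensions and normalizes over all of them. A minimal sketch, assuming the patched class is importable under its existing `botorch.posteriors.ensemble` path (shapes and weight values are illustrative):

```python
import torch

from botorch.posteriors.ensemble import EnsemblePosterior

# Illustrative shapes: b=2 batch, s=3 ensemble members, q=4 points, m=1 output
values = torch.randn(2, 3, 4, 1)
weights = torch.tensor([1.0, 2.0, 3.0])

posterior = EnsemblePosterior(values=values, weights=weights)

# Normalized along the ensemble dimension only
print(posterior.weights)                # tensor([0.1667, 0.3333, 0.5000])
# Expanded over the batch and normalized over the full 2 x 3 mixture
print(posterior.mixture_weights.sum())  # tensor(1.)
print(posterior.mixture_dims)           # [0, 1]
```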
@@ -55,17 +112,60 @@ def dtype(self) -> torch.dtype:
    @property
    def mean(self) -> Tensor:
        r"""The mean of the posterior as a `(b) x n x m`-dim Tensor."""
-         return self.values.mean(dim=-3)
+         # Weighted average across ensemble dimension
+         return (self.values * self.weights[..., None, None]).sum(dim=-3)

    @property
    def variance(self) -> Tensor:
        r"""The variance of the posterior as a `(b) x n x m`-dim Tensor.

-         Computed as the sample variance across the ensemble outputs.
+         Computed as the weighted sample variance across the ensemble outputs.
+
+         This treats weights as probability weights (normalized to sum to 1) and
+         computes the unbiased weighted sample variance using the formula
+         Var = Σ(w_i * (x_i - μ)²) / (1 - Σw_i²),
+         where the sum over w_i² is taken over the ensemble dimension only.
+         Source: https://en.wikipedia.org/wiki/Weighted_arithmetic_mean under
+         "Reliability Weights".
        """
        if self.ensemble_size == 1:
            return torch.zeros_like(self.values.squeeze(-3))
-         return self.values.var(dim=-3)
+
+         # Add dimensions for query points and outputs to enable broadcasting
+         weights = self.weights[..., None, None]
+         squared_deviations = (self.values - self.mean.unsqueeze(-3)) ** 2
+         return (weights * squared_deviations).sum(dim=-3) / (1 - (weights**2).sum())
+
+     @property
+     def mixture_mean(self) -> Tensor:
+         r"""The mixture mean of the posterior as an `n x m`-dim Tensor.
+
+         Computed as the weighted average across the ensemble and batch outputs.
+         """
+         return (self.values * self.mixture_weights[..., None, None]).sum(
+             dim=self.mixture_dims
+         )
+
+     @property
+     def mixture_variance(self) -> Tensor:
+         r"""The mixture variance of the posterior as an `n x m`-dim Tensor.
+
+         Computed as the weighted sample variance across the ensemble outputs.
+
+         This treats weights as probability weights (normalized to sum to 1) and
+         computes the unbiased weighted sample variance using the formula
+         Var = Σ(w_i * (x_i - μ)²) / (1 - Σw_i²),
+         where w_i is normalized over the entire mixture and the sum over w_i²
+         is taken over all mixture dimensions.
+         Source: https://en.wikipedia.org/wiki/Weighted_arithmetic_mean under
+         "Reliability Weights".
+         """
+         # Add dimensions for query points and outputs to enable broadcasting
+         weights = self.mixture_weights[..., None, None]
+         squared_deviations = (self.values - self.mixture_mean.unsqueeze(-3)) ** 2
+         return (weights * squared_deviations).sum(dim=self.mixture_dims) / (
+             1 - (weights**2).sum()
+         )

    def _extended_shape(
        self,
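The `variance` and `mixture_variance` properties above implement the reliability-weights estimator cited in their docstrings. As a sanity check, a standalone sketch in plain torch (illustrative shapes, not BoTorch API): with uniform weights, Σw_i² = 1/s, so the denominator 1 - Σw_i² = (s - 1)/s and the estimator reduces to the usual unbiased sample variance.

```python
import torch

s = 5
x = torch.randn(s, 4, 1)       # s ensemble members, q=4 points, m=1 output
w = torch.full((s,), 1.0 / s)  # uniform probability weights

# Weighted mean and reliability-weighted variance over the ensemble dim
mean = (w[:, None, None] * x).sum(dim=-3)
sq_dev = (x - mean.unsqueeze(-3)) ** 2
weighted_var = (w[:, None, None] * sq_dev).sum(dim=-3) / (1 - (w**2).sum())

# Matches torch.var, which applies Bessel's 1/(s - 1) correction by default
assert torch.allclose(weighted_var, x.var(dim=-3))
```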
@@ -76,6 +176,10 @@ def _extended_shape(
        """
        return sample_shape + self.values.shape[:-3] + self.values.shape[-2:]

+     @property
+     def batch_shape(self) -> torch.Size:
+         return self.values.shape[:-3]
+
    def rsample(
        self,
        sample_shape: torch.Size | None = None,
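The hunk above adds a `batch_shape` property next to `_extended_shape`. The shape arithmetic is easy to see with plain `torch.Size` values (illustrative sizes only): `_extended_shape` drops the ensemble dimension `s` and prepends `sample_shape`, while `batch_shape` is everything before the last three dimensions.

```python
import torch

values_shape = torch.Size([2, 3, 4, 1])  # (b) x s x q x m
sample_shape = torch.Size([16])

batch_shape = values_shape[:-3]          # torch.Size([2])
extended = sample_shape + values_shape[:-3] + values_shape[-2:]
print(extended)                          # torch.Size([16, 2, 4, 1])
```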
@@ -94,17 +198,26 @@ def rsample(
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
-         if sample_shape is None:
+         if sample_shape is None or len(sample_shape) == 0:
            sample_shape = torch.Size([1])
-         # get indices as base_samples
+
+         # NOTE: An empty `values` tensor occasionally occurs in hypervolume
+         # evaluations when there are no points which improve over the reference
+         # point. In that case, we create a posterior for all the points which
+         # improve over the reference, which is an empty set.
+         if self.values.numel() == 0:
+             return torch.empty(
+                 *self._extended_shape(sample_shape=sample_shape),
+                 device=self.device,
+                 dtype=self.dtype,
+             )
+
        base_samples = (
-             torch.multinomial(
-                 self.weights,
-                 num_samples=sample_shape.numel(),
-                 replacement=True,
+             Multinomial(
+                 probs=self.mixture_weights,
            )
-             .reshape(sample_shape)
-             .to(device=self.device)
+             .sample(sample_shape=sample_shape)
+             .argmax(dim=-1)
        )
        return self.rsample_from_base_samples(
            sample_shape=sample_shape, base_samples=base_samples
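A note on the sampling change above: `torch.multinomial` only accepts 1- or 2-D probability tensors, so it cannot draw indices against batched `mixture_weights` directly. `Multinomial` with its default `total_count=1` samples one-hot count vectors over the last dimension of `probs`, and `argmax(dim=-1)` recovers one categorical index per batch element. A small sketch with illustrative weights:

```python
import torch
from torch.distributions.multinomial import Multinomial

probs = torch.tensor([[0.2, 0.3, 0.5],
                      [0.6, 0.3, 0.1]])  # batch of 2 weight vectors over s=3
sample_shape = torch.Size([16])

# total_count=1 (the default) -> one-hot counts over the last dim
counts = Multinomial(probs=probs).sample(sample_shape=sample_shape)
print(counts.shape)   # torch.Size([16, 2, 3])

indices = counts.argmax(dim=-1)
print(indices.shape)  # torch.Size([16, 2]) == sample_shape + batch_shape
```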
@@ -132,9 +245,31 @@ def rsample_from_base_samples(
            Samples from the posterior, a tensor of shape
            `self._extended_shape(sample_shape=sample_shape)`.
        """
-         if base_samples.shape != sample_shape:
-             raise ValueError("Base samples do not match sample shape.")
-         # move sample axis to front
-         values = self.values.movedim(-3, 0)
-         # sample from the first dimension of values
-         return values[base_samples, ...]
+         # Check that base_samples has shape sample_shape + batch_shape
+         if base_samples.shape != sample_shape + self.batch_shape:
+             raise ValueError(
+                 f"Expected base_samples.shape to be sample_shape + batch_shape="
+                 f"{sample_shape + self.batch_shape}, but got {base_samples.shape}."
+             )
+
+         if self.batch_shape:
+             # After this reshape, `values` is always four-dimensional, even if
+             # there is more than one batch dimension
+             values = self.values.reshape(
+                 (self.batch_shape.numel(),) + self.values.shape[-3:]
+             )
+
+             # Collapse the base samples to enable index-selecting along the
+             # ensemble dim (dim -3)
+             batch_numel = self.batch_shape.numel()
+             collapsed_base_samples = base_samples.reshape(
+                 sample_shape + (batch_numel,)
+             )
+
+             # The first index is 0, 1, ..., batch_numel - 1, flattening the
+             # batch dimensions; the second index extracts the sampled ensemble
+             # member for each element of the flattened batch
+             return values[torch.arange(batch_numel), collapsed_base_samples].reshape(
+                 self._extended_shape(sample_shape=sample_shape)
+             )
+         return self.values[base_samples]
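Putting the pieces together, a minimal end-to-end sketch of the new sampling path, again assuming the patched class is importable under its existing path (shapes and weights are illustrative):

```python
import torch

from botorch.posteriors.ensemble import EnsemblePosterior

values = torch.randn(2, 3, 4, 1)  # b=2, s=3, q=4, m=1
posterior = EnsemblePosterior(
    values=values, weights=torch.tensor([1.0, 2.0, 3.0])
)

samples = posterior.rsample(sample_shape=torch.Size([8]))
# Each draw picks one ensemble member per batch element
print(samples.shape)  # torch.Size([8, 2, 4, 1]) == sample_shape x (b) x q x m
assert samples.shape == posterior._extended_shape(torch.Size([8]))
```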