@@ -55,6 +55,7 @@ def __init__(
55
55
self ,
56
56
model : Model ,
57
57
posterior_transform : PosteriorTransform | None = None ,
58
+ allow_multi_output : bool = False ,
58
59
) -> None :
59
60
r"""Base constructor for analytic acquisition functions.
60
61
@@ -63,10 +64,12 @@ def __init__(
63
64
posterior_transform: A PosteriorTransform. If using a multi-output model,
64
65
a PosteriorTransform that transforms the multi-output posterior into a
65
66
single-output posterior is required.
67
+ allow_multi_output: If False, requires a posterior_transform if a
68
+ multi-output model is passed.
66
69
"""
67
70
super ().__init__ (model = model )
68
71
if posterior_transform is None :
69
- if model .num_outputs != 1 :
72
+ if not allow_multi_output and model .num_outputs != 1 :
70
73
raise UnsupportedError (
71
74
"Must specify a posterior transform when using a "
72
75
"multi-output model."
@@ -89,21 +92,21 @@ def _mean_and_sigma(
89
92
"""Computes the first and second moments of the model posterior.
90
93
91
94
Args:
92
- X: `batch_shape x q x d`-dim Tensor of model inputs .
95
+ X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points .
93
96
compute_sigma: Boolean indicating whether or not to compute the second
94
97
moment (default: True).
95
98
min_var: The minimum value the variance is clamped too. Should be positive.
96
99
97
100
Returns:
98
- A tuple of tensors containing the first and second moments of the model
99
- posterior. Removes the last two dimensions if they have size one. Only
100
- returns a single tensor of means if compute_sigma is True .
101
+ A tuple of tensors of shape `(b1 x ... x bk) x m` containing the first and
102
+ second moments of the model posterior, where `m` is the number of outputs.
103
+ Returns `None` instead of the second tensor if ` compute_sigma` is False .
101
104
"""
102
105
self .to (X ) # ensures buffers / parameters are on the same device and dtype
103
106
posterior = self .model .posterior (
104
107
X = X , posterior_transform = self .posterior_transform
105
108
)
106
- mean = posterior .mean .squeeze (- 2 ). squeeze ( - 1 ) # removing redundant dimensions
109
+ mean = posterior .mean .squeeze (- 2 ) # remove q-batch dimension
107
110
if not compute_sigma :
108
111
return mean , None
109
112
sigma = posterior .variance .clamp_min (min_var ).sqrt ().view (mean .shape )
@@ -168,9 +171,9 @@ def forward(self, X: Tensor) -> Tensor:
168
171
A `(b1 x ... bk)`-dim tensor of Log Probability of Improvement values at
169
172
the given design points `X`.
170
173
"""
171
- mean , sigma = self ._mean_and_sigma (X )
174
+ mean , sigma = self ._mean_and_sigma (X ) # `(b1 x ... bk) x 1`
172
175
u = _scaled_improvement (mean , sigma , self .best_f , self .maximize )
173
- return log_Phi (u )
176
+ return log_Phi (u . squeeze ( - 1 ) )
174
177
175
178
176
179
class ProbabilityOfImprovement (AnalyticAcquisitionFunction ):
@@ -223,9 +226,9 @@ def forward(self, X: Tensor) -> Tensor:
223
226
A `(b1 x ... bk)`-dim tensor of Probability of Improvement values at the
224
227
given design points `X`.
225
228
"""
226
- mean , sigma = self ._mean_and_sigma (X )
229
+ mean , sigma = self ._mean_and_sigma (X ) # `(b1 x ... bk) x 1`
227
230
u = _scaled_improvement (mean , sigma , self .best_f , self .maximize )
228
- return Phi (u )
231
+ return Phi (u . squeeze ( - 1 ) )
229
232
230
233
231
234
class qAnalyticProbabilityOfImprovement (AnalyticAcquisitionFunction ):
@@ -354,9 +357,9 @@ def forward(self, X: Tensor) -> Tensor:
354
357
A `(b1 x ... bk)`-dim tensor of Expected Improvement values at the
355
358
given design points `X`.
356
359
"""
357
- mean , sigma = self ._mean_and_sigma (X )
360
+ mean , sigma = self ._mean_and_sigma (X ) # `(b1 x ... bk) x 1`
358
361
u = _scaled_improvement (mean , sigma , self .best_f , self .maximize )
359
- return sigma * _ei_helper (u )
362
+ return ( sigma * _ei_helper (u )). squeeze ( - 1 )
360
363
361
364
362
365
class LogExpectedImprovement (AnalyticAcquisitionFunction ):
@@ -418,9 +421,9 @@ def forward(self, X: Tensor) -> Tensor:
418
421
A `(b1 x ... bk)`-dim tensor of the logarithm of the Expected Improvement
419
422
values at the given design points `X`.
420
423
"""
421
- mean , sigma = self ._mean_and_sigma (X )
424
+ mean , sigma = self ._mean_and_sigma (X ) # `(b1 x ... bk) x 1`
422
425
u = _scaled_improvement (mean , sigma , self .best_f , self .maximize )
423
- return _log_ei_helper (u ) + sigma .log ()
426
+ return ( _log_ei_helper (u ) + sigma .log ()). squeeze ( - 1 )
424
427
425
428
426
429
class ConstrainedAnalyticAcquisitionFunctionMixin (ABC ):
@@ -433,7 +436,7 @@ def __init__(
433
436
r"""Analytic Log Probability of Feasibility.
434
437
435
438
Args:
436
- model: A fitted multi-output model.
439
+ model: A fitted single- or multi-output model.
437
440
constraints: A dictionary of the form `{i: [lower, upper]}`, where
438
441
`i` is the output index, and `lower` and `upper` are lower and upper
439
442
bounds on that output (resp. interpreted as -Inf / Inf if None).
@@ -501,13 +504,11 @@ def _compute_log_prob_feas(
501
504
r"""Compute logarithm of the feasibility probability for each batch of X.
502
505
503
506
Args:
504
- X: A `(b) x 1 x d`-dim Tensor of `(b)` t-batches of `d`-dim design
505
- points each.
506
507
means: A `(b) x m`-dim Tensor of means.
507
508
sigmas: A `(b) x m`-dim Tensor of standard deviations.
508
509
509
510
Returns:
510
- A `b `-dim tensor of log feasibility probabilities
511
+ A `(b) `-dim tensor of log feasibility probabilities
511
512
512
513
Note: This function does case-work for upper bound, lower bound, and both-sided
513
514
bounds. Another way to do it would be to use 'inf' and -'inf' for the
@@ -567,7 +568,7 @@ def __init__(
567
568
r"""Analytic Log Constrained Expected Improvement.
568
569
569
570
Args:
570
- model: A fitted multi-output model.
571
+ model: A fitted single- or multi-output model.
571
572
best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing
572
573
the best feasible function value observed so far (assumed noiseless).
573
574
objective_index: The index of the objective.
@@ -576,8 +577,7 @@ def __init__(
576
577
bounds on that output (resp. interpreted as -Inf / Inf if None)
577
578
maximize: If True, consider the problem a maximization problem.
578
579
"""
579
- # Use AcquisitionFunction constructor to avoid check for posterior transform.
580
- AcquisitionFunction .__init__ (self , model = model )
580
+ super ().__init__ (model = model , allow_multi_output = True )
581
581
self .posterior_transform = None
582
582
self .maximize = maximize
583
583
self .objective_index = objective_index
@@ -641,13 +641,12 @@ def __init__(
641
641
r"""Analytic Log Probability of Feasibility.
642
642
643
643
Args:
644
- model: A fitted multi-output model.
644
+ model: A fitted single- or multi-output model.
645
645
constraints: A dictionary of the form `{i: [lower, upper]}`, where
646
646
`i` is the output index, and `lower` and `upper` are lower and upper
647
647
bounds on that output (resp. interpreted as -Inf / Inf if None)
648
648
"""
649
- # Use AcquisitionFunction constructor to avoid check for posterior transform.
650
- AcquisitionFunction .__init__ (self , model = model )
649
+ super ().__init__ (model = model , allow_multi_output = True )
651
650
self .posterior_transform = None
652
651
ConstrainedAnalyticAcquisitionFunctionMixin .__init__ (self , constraints )
653
652
@@ -708,7 +707,7 @@ def __init__(
708
707
r"""Analytic Constrained Expected Improvement.
709
708
710
709
Args:
711
- model: A fitted multi-output model.
710
+ model: A fitted single- or multi-output model.
712
711
best_f: Either a scalar or a `b`-dim Tensor (batch mode) representing
713
712
the best feasible function value observed so far (assumed noiseless).
714
713
objective_index: The index of the objective.
@@ -718,8 +717,7 @@ def __init__(
718
717
maximize: If True, consider the problem a maximization problem.
719
718
"""
720
719
legacy_ei_numerics_warning (legacy_name = type (self ).__name__ )
721
- # Use AcquisitionFunction constructor to avoid check for posterior transform.
722
- AcquisitionFunction .__init__ (self , model = model )
720
+ super ().__init__ (model = model , allow_multi_output = True )
723
721
self .posterior_transform = None
724
722
self .maximize = maximize
725
723
self .objective_index = objective_index
@@ -828,7 +826,9 @@ def forward(self, X: Tensor) -> Tensor:
828
826
the given design points `X`.
829
827
"""
830
828
# add batch dimension for broadcasting to fantasy models
829
+ # (b1 x ... x bk) x num_fantasies x 1
831
830
mean , sigma = self ._mean_and_sigma (X .unsqueeze (- 3 ))
831
+ mean , sigma = mean .squeeze (- 1 ), sigma .squeeze (- 1 )
832
832
u = _scaled_improvement (mean , sigma , self .best_f , self .maximize )
833
833
log_ei = _log_ei_helper (u ) + sigma .log ()
834
834
# this is mathematically - though not numerically - equivalent to log(mean(ei))
@@ -906,14 +906,16 @@ def forward(self, X: Tensor) -> Tensor:
906
906
r"""Evaluate Expected Improvement on the candidate set X.
907
907
908
908
Args:
909
- X: A `b1 x ... bk x 1 x d`-dim batched tensor of `d`-dim design points.
909
+ X: A `( b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
910
910
911
911
Returns:
912
- A `b1 x ... bk`-dim tensor of Noisy Expected Improvement values at
912
+ A `( b1 x ... bk) `-dim tensor of Noisy Expected Improvement values at
913
913
the given design points `X`.
914
914
"""
915
915
# add batch dimension for broadcasting to fantasy models
916
- mean , sigma = self ._mean_and_sigma (X .unsqueeze (- 3 ))
916
+ # (b1 x ... x bk) x num_fantasies x 1
917
+ mean , sigma = self ._mean_and_sigma (X .unsqueeze (- 3 )) # (b1 x ... x bk) x m1 x 1
918
+ mean , sigma = mean .squeeze (- 1 ), sigma .squeeze (- 1 )
917
919
u = _scaled_improvement (mean , sigma , self .best_f , self .maximize )
918
920
return (sigma * _ei_helper (u )).mean (dim = - 1 )
919
921
@@ -970,8 +972,9 @@ def forward(self, X: Tensor) -> Tensor:
970
972
A `(b1 x ... bk)`-dim tensor of Upper Confidence Bound values at the
971
973
given design points `X`.
972
974
"""
973
- mean , sigma = self ._mean_and_sigma (X )
974
- return (mean if self .maximize else - mean ) + self .beta .sqrt () * sigma
975
+ mean , sigma = self ._mean_and_sigma (X ) # (b1 x ... x bk) x 1
976
+ ucb = (mean if self .maximize else - mean ) + self .beta .sqrt () * sigma
977
+ return ucb .squeeze (- 1 )
975
978
976
979
977
980
class PosteriorMean (AnalyticAcquisitionFunction ):
@@ -1020,8 +1023,10 @@ def forward(self, X: Tensor) -> Tensor:
1020
1023
A `(b1 x ... bk)`-dim tensor of Posterior Mean values at the
1021
1024
given design points `X`.
1022
1025
"""
1023
- mean , _ = self ._mean_and_sigma (X , compute_sigma = False )
1024
- return mean if self .maximize else - mean
1026
+ mean , _ = self ._mean_and_sigma (X , compute_sigma = False ) # (b1 x ... x bk) x 1
1027
+ if not self .maximize :
1028
+ mean = - mean
1029
+ return mean .squeeze (- 1 )
1025
1030
1026
1031
1027
1032
class ScalarizedPosteriorMean (AnalyticAcquisitionFunction ):
@@ -1056,14 +1061,16 @@ def forward(self, X: Tensor) -> Tensor:
1056
1061
r"""Evaluate the scalarized posterior mean on the candidate set X.
1057
1062
1058
1063
Args:
1059
- X: A `(b ) x q x d`-dim Tensor of `(b)` t-batches of `d`-dim design
1060
- points each.
1064
+ X: A `(b1 x ... x bk ) x q x d`-dim Tensor of `(b1 x ... x bk)`
1065
+ t-batches of `d`-dim design points each.
1061
1066
1062
1067
Returns:
1063
- A `(b )`-dim Tensor of Posterior Mean values at the given design
1064
- points `X`.
1068
+ A `(b1 x ... x bk )`-dim Tensor of Posterior Mean values at the given
1069
+ design points `X`.
1065
1070
"""
1066
- return self ._mean_and_sigma (X , compute_sigma = False )[0 ] @ self .weights
1071
+ # (b1 x ... x bk) x q x 1
1072
+ mean , _ = self ._mean_and_sigma (X , compute_sigma = False )
1073
+ return mean .squeeze (- 1 ) @ self .weights
1067
1074
1068
1075
1069
1076
class PosteriorStandardDeviation (AnalyticAcquisitionFunction ):
@@ -1131,7 +1138,9 @@ def forward(self, X: Tensor) -> Tensor:
1131
1138
given design points `X`.
1132
1139
"""
1133
1140
_ , std = self ._mean_and_sigma (X )
1134
- return std if self .maximize else - std
1141
+ if not self .maximize :
1142
+ std = - std
1143
+ return std .view (X .shape [:- 2 ])
1135
1144
1136
1145
1137
1146
# --------------- Helper functions for analytic acquisition functions. ---------------
0 commit comments