@@ -66,6 +66,7 @@ def __init__(
6666 num_mv_samples : int = 10 ,
6767 num_y_samples : int = 128 ,
6868 use_gumbel : bool = True ,
69+ maximize : bool = True ,
6970 X_pending : Optional [Tensor ] = None ,
7071 ) -> None :
7172 r"""Single-outcome max-value entropy search acquisition function.
@@ -83,6 +84,7 @@ def __init__(
8384 use_gumbel: If True, use Gumbel approximation to sample the max values.
8485 X_pending: A `m x d`-dim Tensor of `m` design points that have been
8586 submitted for function evaluation but have not yet been evaluated.
87+ maximize: If True, consider the problem a maximization problem.
8688 """
8789 sampler = SobolQMCNormalSampler (num_y_samples )
8890 super ().__init__ (model = model , sampler = sampler )
@@ -108,6 +110,8 @@ def __init__(
108110 self .num_fantasies = num_fantasies
109111 self .use_gumbel = use_gumbel
110112 self .num_mv_samples = num_mv_samples
113+ self .maximize = maximize
114+ self .weight = 1.0 if maximize else - 1.0
111115
112116 # If we put the `self._sample_max_values()` call into `set_X_pending()`,
113117 # it will throw errors when the initial `super().__init__()` is called,
@@ -166,11 +170,11 @@ def _sample_max_values(self):
166170 # sample max values
167171 if self .use_gumbel :
168172 self .posterior_max_values = _sample_max_value_Gumbel (
169- self .model , candidate_set , self .num_mv_samples
173+ self .model , candidate_set , self .num_mv_samples , self . maximize
170174 )
171175 else :
172176 self .posterior_max_values = _sample_max_value_Thompson (
173- self .model , candidate_set , self .num_mv_samples
177+ self .model , candidate_set , self .num_mv_samples , self . maximize
174178 )
175179
176180 @t_batch_mode_transform (expected_q = 1 )
@@ -186,7 +190,8 @@ def forward(self, X: Tensor) -> Tensor:
186190 """
187191 # Compute the posterior, posterior mean, variance and std
188192 posterior = self .model .posterior (X .unsqueeze (- 3 ), observation_noise = False )
189- mean = posterior .mean .squeeze (- 1 ).squeeze (- 1 ) # batch_shape x num_fantasies
193+ mean = self .weight * posterior .mean .squeeze (- 1 ).squeeze (- 1 )
194+ # batch_shape x num_fantasies
190195 variance = posterior .variance .clamp_min (CLAMP_LB ).view_as (mean )
191196 check_no_nans (mean )
192197 check_no_nans (variance )
@@ -228,14 +233,14 @@ def _compute_information_gain(
228233
229234 # compute the std_m, variance_m with noisy observation
230235 posterior_m = self .model .posterior (X .unsqueeze (- 3 ), observation_noise = True )
231- mean_m = posterior_m .mean .squeeze (- 1 )
236+ mean_m = self . weight * posterior_m .mean .squeeze (- 1 )
232237 # batch_shape x num_fantasies x (1 + num_trace_observations)
233238 variance_m = posterior_m .mvn .covariance_matrix
234239 # batch_shape x num_fantasies x (1 + num_trace_observations)^2
235240 check_no_nans (variance_m )
236241
237242 # compute mean and std for fM|ym, x, Dt ~ N(u, s^2)
238- samples_m = self .sampler (posterior_m ).squeeze (- 1 )
243+ samples_m = self .weight * self . sampler (posterior_m ).squeeze (- 1 )
239244 # s_m x batch_shape x num_fantasies x (1 + num_trace_observations)
240245 L = torch .cholesky (variance_m )
241246 temp_term = torch .cholesky_solve (covar_mM .unsqueeze (- 1 ), L ).transpose (- 2 , - 1 )
@@ -279,7 +284,7 @@ def _compute_information_gain(
279284 # s_M x 1 x batch_shape x num_fantasies
280285
281286 # Compute log(p(ym | x, Dt))
282- log_pdf_fm = posterior_m .mvn .log_prob (samples_m ).unsqueeze (0 )
287+ log_pdf_fm = posterior_m .mvn .log_prob (self . weight * samples_m ).unsqueeze (0 )
283288 # 1 x s_m x batch_shape x num_fantasies
284289
285290 # H0 = H(ym | x, Dt)
@@ -331,6 +336,7 @@ def __init__(
331336 num_y_samples : int = 128 ,
332337 use_gumbel : bool = True ,
333338 X_pending : Optional [Tensor ] = None ,
339+ maximize : bool = True ,
334340 cost_aware_utility : Optional [CostAwareUtility ] = None ,
335341 project : Callable [[Tensor ], Tensor ] = lambda X : X ,
336342 expand : Callable [[Tensor ], Tensor ] = lambda X : X ,
@@ -353,6 +359,7 @@ def __init__(
353359 use_gumbel: If True, use Gumbel approximation to sample the max values.
354360 X_pending: A `m x d`-dim Tensor of `m` design points that have been
355361 submitted for function evaluation but have not yet been evaluated.
362+ maximize: If True, consider the problem a maximization problem.
356363 cost_aware_utility: A CostAwareUtility computing the cost-transformed
357364 utility from a candidate set and samples of increases in utility.
358365 project: A callable mapping a `batch_shape x q x d` tensor of design
@@ -372,6 +379,7 @@ def __init__(
372379 num_y_samples = num_y_samples ,
373380 X_pending = X_pending ,
374381 use_gumbel = use_gumbel ,
382+ maximize = maximize ,
375383 )
376384
377385 if cost_aware_utility is None :
@@ -382,8 +390,8 @@ def __init__(
382390 self .cost_aware_utility = cost_aware_utility
383391 self .expand = expand
384392 self .project = project
385- # @TODO make sure fidelity_dims in project, expand & cost_aware_utility align
386- # seems it is very difficult in the current way of handling project/expand
393+ # @TODO make sure fidelity_dims align in project, expand & cost_aware_utility
394+ # It seems very difficult due to the current way of handling project/expand
387395
388396 # resample max values after initializing self.project
389397 # so that the max value samples are at the highest fidelity
@@ -408,7 +416,7 @@ def forward(self, X: Tensor) -> Tensor:
408416 # Compute the posterior, posterior mean, variance without noise
409417 # `_m` and `_M` in the variable names denote the current and the max fidelity.
410418 posterior = self .model .posterior (X_all , observation_noise = False )
411- mean_M = posterior .mean [..., - 1 , 0 ] # batch_shape x num_fantasies
419+ mean_M = self . weight * posterior .mean [..., - 1 , 0 ] # batch_shape x num_fantasies
412420 variance_M = posterior .variance [..., - 1 , 0 ].clamp_min (CLAMP_LB )
413421 # get the covariance between the low fidelities and max fidelity
414422 covar_mM = posterior .mvn .covariance_matrix [..., :- 1 , - 1 ]
@@ -427,7 +435,7 @@ def forward(self, X: Tensor) -> Tensor:
427435
428436
429437def _sample_max_value_Thompson (
430- model : Model , candidate_set : Tensor , num_samples : int
438+ model : Model , candidate_set : Tensor , num_samples : int , maximize : bool = True
431439) -> Tensor :
432440 """Samples the max values by discrete Thompson sampling.
433441
@@ -436,14 +444,16 @@ def _sample_max_value_Thompson(
436444 Args:
437445 model: A fitted single-outcome model.
438446 candidate_set: A `n x d` Tensor including `n` candidate points to
439- discretize the design space
440- num_samples: Number of max value samples
447+ discretize the design space.
448+ num_samples: Number of max value samples.
449+ maximize: If True, consider the problem a maximization problem.
441450
442451 Returns:
443452 A `num_samples x num_fantasies` Tensor of max value samples
444453 """
445454 posterior = model .posterior (candidate_set )
446- samples = posterior .rsample (torch .Size ([num_samples ])).squeeze (- 1 )
455+ weight = 1.0 if maximize else - 1.0
456+ samples = weight * posterior .rsample (torch .Size ([num_samples ])).squeeze (- 1 )
447457 # samples is num_samples x (num_fantasies) x n
448458 max_values , _ = samples .max (dim = - 1 )
449459 if len (samples .shape ) == 2 :
@@ -453,7 +463,7 @@ def _sample_max_value_Thompson(
453463
454464
455465def _sample_max_value_Gumbel (
456- model : Model , candidate_set : Tensor , num_samples : int
466+ model : Model , candidate_set : Tensor , num_samples : int , maximize : bool = True
457467) -> Tensor :
458468 """Samples the max values by Gumbel approximation.
459469
@@ -462,15 +472,17 @@ def _sample_max_value_Gumbel(
462472 Args:
463473 model: A fitted single-outcome model.
464474 candidate_set: A `n x d` Tensor including `n` candidate points to
465- discretize the design space
466- num_samples: Number of max value samples
475+ discretize the design space.
476+ num_samples: Number of max value samples.
477+ maximize: If True, consider the problem a maximization problem.
467478
468479 Returns:
469480 A `num_samples x num_fantasies` Tensor of max value samples
470481 """
471482 # define the approximate CDF for the max value under the independence assumption
472483 posterior = model .posterior (candidate_set )
473- mu = posterior .mean
484+ weight = 1.0 if maximize else - 1.0
485+ mu = weight * posterior .mean
474486 sigma = posterior .variance .clamp_min (1e-8 ).sqrt ()
475487 # mu, sigma is (num_fantasies) X n X 1
476488 if len (mu .shape ) == 3 and mu .shape [- 1 ] == 1 :
0 commit comments