1010
1111from __future__ import annotations
1212
13- from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union
13+ from typing import Any , Dict , List , Optional , Tuple , Type , Union
1414
1515import numpy as np
1616import torch
17+ from botorch .acquisition import AcquisitionFunction
18+ from botorch .generation .utils import _remove_fixed_features_from_optimization
1719from botorch .optim .parameter_constraints import (
1820 _arrayify ,
1921 make_scipy_bounds ,
2325from botorch .optim .utils import _filter_kwargs , columnwise_clamp , fix_features
2426from scipy .optimize import minimize
2527from torch import Tensor
26- from torch .nn import Module
2728from torch .optim import Optimizer
2829
2930
3031def gen_candidates_scipy (
3132 initial_conditions : Tensor ,
32- acquisition_function : Module ,
33+ acquisition_function : AcquisitionFunction ,
3334 lower_bounds : Optional [Union [float , Tensor ]] = None ,
3435 upper_bounds : Optional [Union [float , Tensor ]] = None ,
3536 inequality_constraints : Optional [List [Tuple [Tensor , Tensor , float ]]] = None ,
@@ -83,9 +84,45 @@ def gen_candidates_scipy(
8384 )
8485 """
8586 options = options or {}
87+
88+ # REDUCED is used to indicate whether we are optimizing over a reduced domain
89+ # dimension after considering fixed_features.
90+ # We are in REDUCED mode if fixed_features is not None, except when
91+ # fixed_features.values() contains None and linear constraints are passed.
92+ REDUCED = fixed_features is not None
93+ if inequality_constraints or equality_constraints :
94+ REDUCED = REDUCED and (None not in fixed_features .values ())
95+
96+ if REDUCED :
97+ _no_fixed_features = _remove_fixed_features_from_optimization (
98+ fixed_features = fixed_features ,
99+ acquisition_function = acquisition_function ,
100+ initial_conditions = initial_conditions ,
101+ lower_bounds = lower_bounds ,
102+ upper_bounds = upper_bounds ,
103+ inequality_constraints = inequality_constraints ,
104+ equality_constraints = equality_constraints ,
105+ )
106+
107+ # call the routine with no fixed_features
108+ clamped_candidates , batch_acquisition = gen_candidates_scipy (
109+ initial_conditions = _no_fixed_features .initial_conditions ,
110+ acquisition_function = _no_fixed_features .acquisition_function ,
111+ lower_bounds = _no_fixed_features .lower_bounds ,
112+ upper_bounds = _no_fixed_features .upper_bounds ,
113+ inequality_constraints = _no_fixed_features .inequality_constraints ,
114+ equality_constraints = _no_fixed_features .equality_constraints ,
115+ options = options ,
116+ fixed_features = None ,
117+ )
118+ clamped_candidates = _no_fixed_features .acquisition_function ._construct_X_full (
119+ clamped_candidates
120+ )
121+ return clamped_candidates , batch_acquisition
122+
86123 clamped_candidates = columnwise_clamp (
87124 X = initial_conditions , lower = lower_bounds , upper = upper_bounds
88- ). requires_grad_ ( True )
125+ )
89126
90127 shapeX = clamped_candidates .shape
91128 x0 = _arrayify (clamped_candidates .view (- 1 ))
@@ -111,7 +148,7 @@ def f(x):
111148 .contiguous ()
112149 .requires_grad_ (True )
113150 )
114- X_fix = fix_features (X = X , fixed_features = fixed_features )
151+ X_fix = fix_features (X , fixed_features = fixed_features )
115152 loss = - acquisition_function (X_fix ).sum ()
116153 # compute gradient w.r.t. the inputs (does not accumulate in leaves)
117154 gradf = _arrayify (torch .autograd .grad (loss , X )[0 ].contiguous ().view (- 1 ))
@@ -137,20 +174,22 @@ def f(x):
137174 options = {k : v for k , v in options .items () if k not in ["method" , "callback" ]},
138175 )
139176 candidates = fix_features (
140- X = torch .from_numpy (res .x ).to (initial_conditions ).view (shapeX ). contiguous ( ),
177+ X = torch .from_numpy (res .x ).to (initial_conditions ).reshape (shapeX ),
141178 fixed_features = fixed_features ,
142179 )
180+
143181 clamped_candidates = columnwise_clamp (
144182 X = candidates , lower = lower_bounds , upper = upper_bounds , raise_on_violation = True
145183 )
146184 with torch .no_grad ():
147185 batch_acquisition = acquisition_function (clamped_candidates )
186+
148187 return clamped_candidates , batch_acquisition
149188
150189
151190def gen_candidates_torch (
152191 initial_conditions : Tensor ,
153- acquisition_function : Callable ,
192+ acquisition_function : AcquisitionFunction ,
154193 lower_bounds : Optional [Union [float , Tensor ]] = None ,
155194 upper_bounds : Optional [Union [float , Tensor ]] = None ,
156195 optimizer : Type [Optimizer ] = torch .optim .Adam ,
@@ -199,10 +238,41 @@ def gen_candidates_torch(
199238 )
200239 """
201240 options = options or {}
241+
242+ # REDUCED is used to indicate whether we are optimizing over a reduced domain
243+ # dimension after considering fixed_features.
244+ REDUCED = fixed_features is not None
245+
246+ if REDUCED :
247+ _no_fixed_features = _remove_fixed_features_from_optimization (
248+ fixed_features = fixed_features ,
249+ acquisition_function = acquisition_function ,
250+ initial_conditions = initial_conditions ,
251+ lower_bounds = lower_bounds ,
252+ upper_bounds = upper_bounds ,
253+ inequality_constraints = None ,
254+ equality_constraints = None ,
255+ )
256+
257+ # call the routine with no fixed_features
258+ clamped_candidates , batch_acquisition = gen_candidates_torch (
259+ initial_conditions = _no_fixed_features .initial_conditions ,
260+ acquisition_function = _no_fixed_features .acquisition_function ,
261+ lower_bounds = _no_fixed_features .lower_bounds ,
262+ upper_bounds = _no_fixed_features .upper_bounds ,
263+ optimizer = optimizer ,
264+ options = options ,
265+ verbose = verbose ,
266+ fixed_features = None ,
267+ )
268+ clamped_candidates = _no_fixed_features .acquisition_function ._construct_X_full (
269+ clamped_candidates
270+ )
271+ return clamped_candidates , batch_acquisition
272+
202273 clamped_candidates = columnwise_clamp (
203274 X = initial_conditions , lower = lower_bounds , upper = upper_bounds
204275 ).requires_grad_ (True )
205- candidates = fix_features (clamped_candidates , fixed_features )
206276 bayes_optimizer = optimizer (
207277 params = [clamped_candidates ], lr = options .get ("lr" , 0.025 )
208278 )
@@ -215,29 +285,33 @@ def gen_candidates_torch(
215285 )
216286 while not stop :
217287 i += 1
218- loss = - acquisition_function (candidates ).sum ()
288+ loss = - acquisition_function (clamped_candidates ).sum ()
219289 if verbose :
220290 print ("Iter: {} - Value: {:.3f}" .format (i , - (loss .item ())))
221291 loss_trajectory .append (loss .item ())
222- param_trajectory ["candidates" ].append (candidates .clone ())
292+ param_trajectory ["candidates" ].append (clamped_candidates .clone ())
223293
224294 def closure ():
225295 bayes_optimizer .zero_grad ()
226296 loss .backward ()
227297 return loss
228298
229299 bayes_optimizer .step (closure )
230- clamped_candidates . data = columnwise_clamp (
231- clamped_candidates , lower_bounds , upper_bounds
232- )
233- candidates = fix_features ( clamped_candidates , fixed_features )
300+ with torch . no_grad ():
301+ clamped_candidates = columnwise_clamp (
302+ X = clamped_candidates , lower = lower_bounds , upper = upper_bounds
303+ )
234304 stop = stopping_criterion .evaluate (fvals = loss .detach ())
235305 clamped_candidates = columnwise_clamp (
236- X = candidates , lower = lower_bounds , upper = upper_bounds , raise_on_violation = True
306+ X = clamped_candidates ,
307+ lower = lower_bounds ,
308+ upper = upper_bounds ,
309+ raise_on_violation = True ,
237310 )
238311 with torch .no_grad ():
239- batch_acquisition = acquisition_function (candidates )
240- return candidates , batch_acquisition
312+ batch_acquisition = acquisition_function (clamped_candidates )
313+
314+ return clamped_candidates , batch_acquisition
241315
242316
243317def get_best_candidates (batch_candidates : Tensor , batch_values : Tensor ) -> Tensor :
0 commit comments