Description
What happened?
Hello, botorch developers.
I am using botorch to tackle multi-objective optimization problems with parameter constraints and outcome constraints.
After upgrading from botorch 0.14.0 to 0.16.1, I noticed that the time spent inside `optimize_acqf`, specifically in generating the batch initial conditions (`gen_batch_initial_conditions`), increased by roughly a factor of 1.5.
By timing each commit with the attached script, I traced the slowdown to the changes in #2920. I understand that this change was intended to fix an OOM issue by improving memory usage and execution efficiency, but I do not understand why it makes my cases slower. Whenever outcome constraints are present, I see a similar slowdown not only in the problem from the attached script but also in several other problems, so it does not seem to depend on a specific optimization setup.
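The per-commit comparison described above comes down to wrapping the function of interest and recording its wall-clock time. Below is a minimal, reusable sketch of that idea. `timed` is a hypothetical helper (not part of BoTorch), and it uses `time.perf_counter`, which is generally preferred over `time.time` for measuring short durations. In the reproduction script it could wrap `gen_batch_initial_conditions` and be passed to `optimize_acqf` via `ic_generator`, the same way the script's own wrapper is.

```python
import functools
import time
from typing import Any, Callable, List


def timed(fn: Callable[..., Any], records: List[float]) -> Callable[..., Any]:
    """Hypothetical helper: wrap `fn` so that each call's wall-clock
    duration (seconds, measured with time.perf_counter) is appended to
    `records`, even if the call raises."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            records.append(time.perf_counter() - start)

    return wrapper


# Usage with a stand-in function; in the script, `fn` would be
# gen_batch_initial_conditions instead of `square`.
durations: List[float] = []


def square(x: int) -> int:
    return x * x


timed_square = timed(square, durations)
result = timed_square(12)
print(result, len(durations))  # 144 1
```

Collecting durations in a list instead of printing them makes it easy to compute per-version statistics afterwards.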
Please provide a minimal, reproducible example of the unexpected behavior.
```python
from time import time
import numpy as np
import torch
from gpytorch.mlls import SumMarginalLogLikelihood
import botorch
from botorch.models import SingleTaskGP, ModelListGP
from botorch.utils.transforms import normalize, standardize
from botorch.fit import fit_gpytorch_mll
from botorch.acquisition.multi_objective import qLogExpectedHypervolumeImprovement
from botorch.acquisition.multi_objective import IdentityMCMultiOutputObjective
from botorch.utils.multi_objective.box_decompositions import NondominatedPartitioning
from botorch.sampling import SobolQMCNormalSampler
from botorch.optim import optimize_acqf
from botorch.optim.initializers import gen_batch_initial_conditions


# gen_batch_initial_conditions with time check
def time_check_gen_batch_initial_conditions(*args, **kwargs):
    start = time()
    out = gen_batch_initial_conditions(*args, **kwargs)
    end = time()
    print(f' gen_batch_initial_conditions: {end - start:.2f} sec.')
    return out


# Multi objective function
def objectives(X: torch.Tensor):
    return X


# function for outcome constraint
def constraint_violations(X: torch.Tensor):
    """x + y <= 10"""
    x = X[:, 0].unsqueeze(-1)
    y = X[:, 1].unsqueeze(-1)
    return x + y - 10.


# function for Monte Carlo outcome constraint (BoTorch treats negative values as feasible)
def cv_func_monte_carlo(Z, idx):
    return Z[..., idx]


def main():
    print(f'botorch version: {botorch.__version__}')
    for seed in range(3):
        np.random.seed(seed)
        with botorch.manual_seed(seed):
            # initial data
            N = 50
            bounds = np.array(((0, 10), (0, 10)), dtype=float)
            x = np.concatenate(
                [
                    np.random.rand(N, 1) * (bounds[i][1] - bounds[i][0]) + bounds[i][0]
                    for i in range(len(bounds))
                ],
                axis=-1,
            )
            # make tensor
            B = torch.tensor(bounds).transpose(0, 1)
            X = torch.tensor(x)
            # calc feasibility
            CV = constraint_violations(X)
            feas_idx = (CV <= 0).all(dim=-1)
            # calc objectives
            feas_X = X[feas_idx]
            feas_Y = objectives(feas_X)
            # print
            print(f' {len(feas_X)=}')
            # detach, transform
            feas_X = normalize(feas_X.detach().clone(), bounds=B)
            feas_Y = standardize(feas_Y.detach().clone())
            X = normalize(X.detach().clone(), bounds=B)
            CV = standardize(CV.detach().clone())
            # train objective model
            model = SingleTaskGP(
                train_X=feas_X,
                train_Y=feas_Y,
                input_transform=None,
                outcome_transform=None,
            )
            # train constraint model
            model_con = SingleTaskGP(
                train_X=X.detach().clone(),
                train_Y=CV.detach().clone(),
                input_transform=None,
                outcome_transform=None,
            )
            # integrate them
            models = ModelListGP(model, model_con)
            mll = SumMarginalLogLikelihood(likelihood=models.likelihood, model=models)
            fit_gpytorch_mll(mll)
            # ACQF setup
            objective = IdentityMCMultiOutputObjective(
                outcomes=list(range(feas_Y.size(-1))),
            )
            # i=i binds the index when each lambda is created (avoids Python's
            # late-binding closure pitfall if there is more than one constraint)
            constraints = [
                lambda Z, i=i: cv_func_monte_carlo(Z, i + feas_Y.size(-1))
                for i in range(CV.size(-1))
            ]
            sampler = SobolQMCNormalSampler(
                sample_shape=torch.Size((256,))
            )
            ref_point = feas_Y.min(dim=0).values.detach().clone()
            ref_point -= 1e-8
            alpha = 0.0
            partitioning = NondominatedPartitioning(
                ref_point=ref_point,
                Y=feas_Y,
                alpha=alpha,
            )
            acqf = qLogExpectedHypervolumeImprovement(
                model=models,
                ref_point=ref_point.tolist(),
                partitioning=partitioning,
                sampler=sampler,
                objective=objective,
                constraints=constraints,
            )
            # optimize acqf with the timed initial-condition generator
            options = {
                "batch_limit": 1,  # required for non-linear inequality constraints
                "maxiter": 200,
            }
            start = time()
            candidate, acqf_value = optimize_acqf(
                acq_function=acqf,
                bounds=B,
                q=1,
                num_restarts=20,
                raw_samples=1024,
                options=options,
                return_best_only=True,
                sequential=True,
                # nonlinear_inequality_constraints=...,
                ic_generator=time_check_gen_batch_initial_conditions,
            )
            end = time()
            print(f' Elapsed time: {end - start:.2f} sec.')


if __name__ == '__main__':
    main()
```

Execution summary of `gen_batch_initial_conditions` time:

batch_limit = 1

| botorch version | mean ± stddev over 3 seeds (sec) |
|---|---|
| 0.14.0 | 7.87 ± 0.53 |
| 0.14.1.dev41 (commit 44299a1) | 7.74 ± 0.53 |
| 0.14.1.dev42 (commit 71f03ea) | 12.59 ± 1.38 |
| 0.16.1 | 11.51 ± 0.56 |

batch_limit = 5

| botorch version | mean ± stddev over 3 seeds (sec) |
|---|---|
| 0.14.0 | 2.92 ± 0.36 |
| 0.14.1.dev41 (commit 44299a1) | 2.98 ± 0.50 |
| 0.14.1.dev42 (commit 71f03ea) | 3.92 ± 0.79 |
| 0.16.1 | 3.92 ± 0.81 |
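As a quick check on the tables, the sketch below flags the first version whose mean time jumps by more than a chosen factor over its predecessor, and reports the overall newest/oldest ratio. The 1.2x threshold and the helper names `first_slowdown` and `overall_ratio` are illustrative assumptions; the means are copied from the tables.

```python
from typing import Optional, Sequence, Tuple

# Mean gen_batch_initial_conditions times (sec) from the tables,
# ordered from oldest to newest version.
BATCH_LIMIT_1 = [
    ("0.14.0", 7.87),
    ("0.14.1.dev41", 7.74),
    ("0.14.1.dev42", 12.59),
    ("0.16.1", 11.51),
]
BATCH_LIMIT_5 = [
    ("0.14.0", 2.92),
    ("0.14.1.dev41", 2.98),
    ("0.14.1.dev42", 3.92),
    ("0.16.1", 3.92),
]


def first_slowdown(
    means: Sequence[Tuple[str, float]], factor: float = 1.2
) -> Optional[str]:
    """Return the first version whose mean exceeds the previous version's
    mean by more than `factor`, or None if no such jump exists."""
    for (_, prev), (label, cur) in zip(means, means[1:]):
        if cur > prev * factor:
            return label
    return None


def overall_ratio(means: Sequence[Tuple[str, float]]) -> float:
    """Ratio of the newest version's mean time to the oldest version's."""
    return means[-1][1] / means[0][1]


for name, means in [("batch_limit=1", BATCH_LIMIT_1), ("batch_limit=5", BATCH_LIMIT_5)]:
    print(name, first_slowdown(means), round(overall_ratio(means), 2))
```

With these means, both batch limits point to 0.14.1.dev42 (commit 71f03ea) as the first slow version, with overall ratios of about 1.46 (batch_limit = 1) and 1.34 (batch_limit = 5).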
All data (raw console output):

```text
batch_limit: 1

botorch version: 0.14.0
len(feas_X)=29
gen_batch_initial_conditions: 8.33 sec.
Elapsed time: 12.07 sec.
len(feas_X)=25
gen_batch_initial_conditions: 7.29 sec.
Elapsed time: 10.95 sec.
len(feas_X)=28
gen_batch_initial_conditions: 7.99 sec.
Elapsed time: 12.36 sec.

(commit: 44299a1)
botorch version: 0.14.1.dev41+g44299a143
len(feas_X)=29
gen_batch_initial_conditions: 8.28 sec.
Elapsed time: 11.97 sec.
len(feas_X)=25
gen_batch_initial_conditions: 7.23 sec.
Elapsed time: 11.00 sec.
len(feas_X)=28
gen_batch_initial_conditions: 7.72 sec.
Elapsed time: 12.30 sec.

(commit: 71f03ea)
botorch version: 0.14.1.dev42+g71f03ea20
len(feas_X)=29
gen_batch_initial_conditions: 14.11 sec.
Elapsed time: 19.73 sec.
len(feas_X)=25
gen_batch_initial_conditions: 12.22 sec.
Elapsed time: 17.29 sec.
len(feas_X)=28
gen_batch_initial_conditions: 11.43 sec.
Elapsed time: 17.30 sec.

botorch version: 0.16.1
len(feas_X)=29
gen_batch_initial_conditions: 12.04 sec.
Elapsed time: 16.83 sec.
len(feas_X)=25
gen_batch_initial_conditions: 10.93 sec.
Elapsed time: 15.95 sec.
len(feas_X)=28
gen_batch_initial_conditions: 11.56 sec.
Elapsed time: 17.43 sec.

batch_limit: 5

botorch version: 0.14.0
len(feas_X)=29
gen_batch_initial_conditions: 3.31 sec.
Elapsed time: 9.65 sec.
len(feas_X)=25
gen_batch_initial_conditions: 2.60 sec.
Elapsed time: 7.43 sec.
len(feas_X)=28
gen_batch_initial_conditions: 2.85 sec.
Elapsed time: 9.11 sec.

(commit: 44299a1)
botorch version: 0.14.1.dev41+g44299a143
len(feas_X)=29
gen_batch_initial_conditions: 3.48 sec.
Elapsed time: 5.41 sec.
len(feas_X)=25
gen_batch_initial_conditions: 2.48 sec.
Elapsed time: 4.30 sec.
len(feas_X)=28
gen_batch_initial_conditions: 2.99 sec.
Elapsed time: 5.48 sec.

(commit: 71f03ea)
botorch version: 0.14.1.dev42+g71f03ea20
len(feas_X)=29
gen_batch_initial_conditions: 4.79 sec.
Elapsed time: 7.28 sec.
len(feas_X)=25
gen_batch_initial_conditions: 3.26 sec.
Elapsed time: 5.46 sec.
len(feas_X)=28
gen_batch_initial_conditions: 3.71 sec.
Elapsed time: 6.85 sec.

botorch version: 0.16.1
len(feas_X)=29
gen_batch_initial_conditions: 4.82 sec.
Elapsed time: 7.34 sec.
len(feas_X)=25
gen_batch_initial_conditions: 3.25 sec.
Elapsed time: 5.48 sec.
len(feas_X)=28
gen_batch_initial_conditions: 3.68 sec.
Elapsed time: 6.85 sec.
```
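The summary tables can be reproduced from the raw output. Here is a minimal sketch, assuming the log format shown above: one `botorch version:` line per block, followed by per-seed `gen_batch_initial_conditions:` lines. `summarize` is a hypothetical helper, and the spread is the sample standard deviation (`statistics.stdev`), which matches the reported values.

```python
import re
import statistics
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

# Excerpt of the raw console output (batch_limit: 1).
LOG = """\
botorch version: 0.14.0
len(feas_X)=29
gen_batch_initial_conditions: 8.33 sec.
len(feas_X)=25
gen_batch_initial_conditions: 7.29 sec.
len(feas_X)=28
gen_batch_initial_conditions: 7.99 sec.
(commit: 44299a1)
botorch version: 0.14.1.dev41+g44299a143
gen_batch_initial_conditions: 8.28 sec.
gen_batch_initial_conditions: 7.23 sec.
gen_batch_initial_conditions: 7.72 sec.
(commit: 71f03ea)
botorch version: 0.14.1.dev42+g71f03ea20
gen_batch_initial_conditions: 14.11 sec.
gen_batch_initial_conditions: 12.22 sec.
gen_batch_initial_conditions: 11.43 sec.
botorch version: 0.16.1
gen_batch_initial_conditions: 12.04 sec.
gen_batch_initial_conditions: 10.93 sec.
gen_batch_initial_conditions: 11.56 sec.
"""

VERSION_RE = re.compile(r"^botorch version:\s*(\S+)")
IC_TIME_RE = re.compile(r"^gen_batch_initial_conditions:\s*([\d.]+) sec\.")


def summarize(log: str) -> Dict[str, Tuple[float, float]]:
    """Map each botorch version to (mean, sample std) of its
    gen_batch_initial_conditions times, in seconds."""
    times: Dict[str, List[float]] = defaultdict(list)
    version: Optional[str] = None
    for line in log.splitlines():
        line = line.strip()  # the script prints some lines with a leading space
        if m := VERSION_RE.match(line):
            version = m.group(1)
        elif (m := IC_TIME_RE.match(line)) and version is not None:
            times[version].append(float(m.group(1)))
    return {v: (statistics.mean(t), statistics.stdev(t)) for v, t in times.items()}


for version, (mean, std) in summarize(LOG).items():
    print(f"{version}: {mean:.2f} ± {std:.2f}")
```

On this excerpt it prints 7.87 ± 0.53, 7.74 ± 0.53, 12.59 ± 1.38 and 11.51 ± 0.56 for the four versions, matching the batch_limit = 1 table.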
Please paste any relevant traceback/logs produced by the example provided.
BoTorch Version
0.14.0, 0.16.1
Python Version
3.13.7
Operating System
Windows 11
(Optional) Describe any potential fixes you've considered to the issue outlined above.
I’m not sure what kind of improvement would be appropriate, but one idea that comes to mind is to provide an option to use the previous implementation.
Thank you very much for your time and support.
Pull Request
None
Code of Conduct
- I agree to follow BoTorch's Code of Conduct