Skip to content

Commit 8cc5007

Browse files
optimization_bacteriarods: use Bayesian optimization from the scikit-optimize package and the differential evolution method for comparison
1 parent 8c23037 commit 8cc5007

File tree

5 files changed

+116
-76
lines changed

5 files changed

+116
-76
lines changed

cr_bayesian_optim/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,6 @@
2929
import cr_bayesian_optim.sim_branching as sim_branching
3030
import cr_bayesian_optim.plotting as plotting
3131
import cr_bayesian_optim.optimization as optimization
32-
from cr_bayesian_optim.optimize_bacterialrods import *
32+
import cr_bayesian_optim.optimize_bacterialrods as optimize_bacterialrods
3333

3434
from .fractal_dim import fractal_dim_main

cr_bayesian_optim/optimization.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
)
1212

1313
import numpy as np
14+
from scipy.optimize import differential_evolution
15+
from skopt import gp_minimize, callbacks
16+
import pickle
1417

1518

1619
def rhs_fractal_dim(options: Options) -> tuple[float, float]:
@@ -35,3 +38,54 @@ def rhs_fractal_dim(options: Options) -> tuple[float, float]:
3538

3639
_, _, popt, pcov = calculate_fractal_dim_for_pos(pos, options)
3740
return popt[0], pcov[0, 0] ** 0.5
41+
42+
43+
def optimization_diff_evolution(cost, bnds, args=(), workers=-1):
    """Minimize *cost* over box bounds with scipy's differential evolution.

    Parameters
    ----------
    cost : callable
        Objective ``cost(x, *args)`` returning a scalar to minimize.
    bnds : sequence of (min, max)
        Bounds for each optimized parameter.
    args : tuple, optional
        Extra positional arguments forwarded to *cost*.
    workers : int, optional
        Number of parallel workers (-1 uses all cores). Parallel evaluation
        requires ``updating='deferred'``, which is set below.

    Returns
    -------
    scipy.optimize.OptimizeResult
        The optimization result (``polish=False``, so no local refinement).
    """
    return differential_evolution(
        cost,
        bounds=bnds,
        args=args,  # BUG FIX: was accepted but silently ignored before
        tol=1e-3,
        atol=1e-3,
        maxiter=10,
        popsize=5,
        init='latinhypercube',
        disp=True,
        polish=False,
        updating='deferred',  # required so `workers` can parallelize
        workers=workers,
        strategy='randtobest1bin',
        # checkpoints intermediate state to disk after each generation
        callback=callback_diffevol,
    )
59+
60+
61+
def callback_diffevol(intermediate_result):
    """Checkpoint callback for differential evolution.

    Pickles the intermediate optimization state after each generation,
    overwriting any previous checkpoint file.
    """
    checkpoint_file = "out/Optimization_result_diffevol.pkl"
    with open(checkpoint_file, 'wb') as checkpoint:
        pickle.dump(intermediate_result, checkpoint, pickle.HIGHEST_PROTOCOL)
64+
65+
66+
def optimization_bayes(cost, bnds, args=(), workers=-1):
    """Minimize *cost* with Gaussian-process Bayesian optimization (skopt).

    Parameters
    ----------
    cost : callable
        Objective ``cost(x, *args)`` returning a scalar to minimize.
    bnds : sequence of (min, max)
        Bounds for each optimized parameter.
    args : tuple, optional
        Extra positional arguments forwarded to *cost*.
    workers : int, optional
        Number of parallel evaluations of the objective (``n_jobs``).

    Returns
    -------
    scipy.optimize.OptimizeResult
        The skopt result object (has ``.x`` and ``.fun``).
    """
    if args:
        # BUG FIX: gp_minimize has no `args` parameter, so extra arguments
        # were previously accepted but silently ignored.
        base_cost = cost
        cost = lambda x: base_cost(x, *args)
    return gp_minimize(
        cost,
        bnds,
        acq_func="EI",          # acquisition: EI, LCB, MES, gp_hedge, PVRS, PI, EIps, PIps
        n_calls=20,             # number of evaluations of the objective
        n_random_starts=5,      # number of random initialization points
        noise=0.,               # assumed observation noise (optional)
        random_state=1234,
        kappa=1.96,
        xi=0.01,
        acq_optimizer='lbfgs',  # needed for parallelization
        n_restarts_optimizer=5, # restarts of the acquisition optimizer
        n_jobs=workers,         # parallel evaluations of the objective
        # checkpoint the run to disk after every iteration
        callback=[callbacks.CheckpointSaver("out/Optimization_result.pkl")],
    )
81+
82+
83+
def save_optimization_result(res, path='', add_filename=''):
    """Pickle an optimization result to disk.

    Writes ``<path>Final_optimization_result<add_filename>.pkl``,
    overwriting any existing file at that location.
    """
    outfile = path + 'Final_optimization_result' + add_filename + '.pkl'
    with open(outfile, 'wb') as outp:
        pickle.dump(res, outp, pickle.HIGHEST_PROTOCOL)
86+
87+
88+
def load_optimization_result(path='', add_filename=''):
    """Load and return a previously saved optimization result pickle.

    Reads ``<path>Final_optimization_result<add_filename>.pkl`` — the
    counterpart of :func:`save_optimization_result`.
    """
    infile = path + 'Final_optimization_result' + add_filename + '.pkl'
    with open(infile, 'rb') as inp:
        return pickle.load(inp)
Lines changed: 33 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import cr_mech_coli as crm
22
import cr_mech_coli.crm_fit as crm_fit
3-
3+
from functools import partial
44
import numpy as np
5-
import matplotlib.pyplot as plt
6-
7-
from bayes_opt import BayesianOptimization, acquisition
85

96

107
def extract_data(image_timesteps, n_vertices):
@@ -16,54 +13,36 @@ def extract_data(image_timesteps, n_vertices):
1613
return data
1714

1815

19-
def cost(data, settings, init_pos, *param):
16+
def cost_bacterialrods(data, settings, init_pos, param):
    """Cost comparing simulated bacterial-rod positions against target data.

    Parameters
    ----------
    data : tuple
        ``(days, x_target)`` — observation days and target vertex positions.
    settings : crm_fit.Settings
        Simulation settings passed to ``crm_fit.predict``.
    init_pos : array-like
        Initial cell positions for the simulation.
    param : sequence
        Parameter vector to evaluate.

    Returns
    -------
    float
        Mean squared difference between target and predicted positions, or
        a large penalty (1e10) when the simulation fails.
    """
    (days, x_target) = data
    container = crm_fit.predict(param, init_pos, settings)
    if container is None:
        # Guard clause: penalize failed simulations instead of aborting,
        # so the optimizer can keep exploring the parameter space.
        print("Simulation Failed")
        return 1e10
    iterations = container.get_all_iterations()
    x_prediction = np.zeros(np.shape(x_target))
    # Mean step between saved iterations; used to map days onto iterations.
    delta_iter = np.mean(np.array(iterations)[1:] - np.array(iterations)[:-1])
    ## TODO why is saved iterations step changes from 952 to 953 ??
    iter_data = delta_iter * (np.array(days) - days[0] + 1)
    ind_last = np.argmin(np.abs(iter_data[-1] - iterations))
    i = 0
    for iter in iterations[:ind_last + 1]:
        # Pick the saved iterations closest (within 1 step) to the data days.
        if np.any(np.abs(iter_data - iter) <= 1.):
            cells = container.get_cells_at_iteration(iter)
            keys = sorted(cells.keys())
            # what is the last dimension: why 3 and not 2 ?
            pos = np.array([cells[key][0].pos for key in keys])[:, :, :-1]
            x_prediction[i] = pos
            i += 1
    return np.mean(squared_difference(x_target, x_prediction))
4139

4240

4341
def squared_difference(x_target, x_prediction):
    """Return the element-wise squared difference between target and prediction."""
    residual = x_target - x_prediction
    return residual ** 2
4543

4644

47-
def posterior(optimizer, grid):
48-
mu, sigma = optimizer._gp.predict(grid, return_std=True)
49-
return mu, sigma
50-
51-
52-
def plot_objective_GP(optimizer, bnds, name=''):
53-
for k in bnds.keys():
54-
fig, ax = plt.subplots()
55-
x_gp = np.linspace(*bnds[k], 100)
56-
mean_gp, sigma_gp = posterior(optimizer, x_gp.reshape(-1, 1))
57-
ax.plot(x_gp, mean_gp, label=k)
58-
ax.fill_between(x_gp, mean_gp + sigma_gp, mean_gp - sigma_gp, alpha=0.1)
59-
ax.scatter(optimizer.space.params.flatten(), optimizer.space.target, c="red", s=50, zorder=10)
60-
ax.legend(fontsize=12)
61-
plt.savefig(f'{k}_{name}'+'.png', bbox_inches='tight')
62-
plt.close(fig)
63-
64-
65-
66-
def optimize_bacterialrods_main():
45+
def create_test_ABM_framework():
6746
n_vertices = 8
6847
# Extract data from masks which have been previously generated
6948
image_timesteps = ['42', '43', '44', '45', '46', '47', '48', '49', '52']
@@ -72,41 +51,20 @@ def optimize_bacterialrods_main():
7251
# Target/model/simulation
7352
# Define settings required to run simulation
7453
settings = crm_fit.Settings.from_toml("data/crm_fit/0001/settings.toml")
75-
settings.constants.n_vertices = n_vertices
7654
settings.constants.n_saves = 15
7755
settings.others = crm_fit.Others(True)
56+
return data, settings
57+
7858

79-
#settings.parameters.damping = crm_fit.SampledFloat(min=0, max=2.5, initial=1.5)
80-
settings.parameters.damping = 2.0
81-
settings.parameters.potential_type.Mie.en = 10.
82-
settings.parameters.potential_type.Mie.em = 1.5
59+
def main_bacterialrods_optimization(optimizer, update_ABM=None):
    """Run a parameter optimization of the bacterial-rod ABM.

    Parameters
    ----------
    optimizer : callable
        Optimization driver ``optimizer(cost, bounds)`` returning a result
        object with ``.x`` and ``.fun`` (e.g. ``optimization_bayes`` or
        ``optimization_diff_evolution``).
    update_ABM : callable, optional
        Hook receiving the settings object and returning a (possibly
        modified) settings object before optimization starts.

    Returns
    -------
    The optimizer's result object.
    """
    # Define ABM framework
    data, settings = create_test_ABM_framework()
    if update_ABM is not None:
        settings = update_ABM(settings)
    lower, upper, x0, param_infos, constants, constant_infos = settings.generate_optimization_infos(len(data[1][0]))
    # FIX: variable names corrected — previously the lower bound was bound to
    # a name `u_b` and the upper bound to `l_b` (values were in the right
    # order, but the names were misleading).
    bnds_dict = {p_inf[0]: (lo, hi) for lo, hi, p_inf in zip(lower, upper, param_infos)}
    print(param_infos)
    # Freeze the data/settings arguments so the optimizer only sees `param`.
    cost_for_optimization = partial(cost_bacterialrods, data, settings, data[1][0])
    res = optimizer(cost_for_optimization, list(bnds_dict.values()))
    print(res.x, res.fun)
    return res

cr_bayesian_optim/plotting.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import tqdm
1313
import multiprocessing as mp
1414
import itertools
15+
from skopt.plots import plot_gaussian_process, plot_convergence, plot_objective
1516

1617
# Define colors
1718
COLOR1 = "#0c457d"
@@ -185,3 +186,24 @@ def generate_movie(opath: Path, play_movie: bool = True):
185186
print("Playing Movie")
186187
bashcmd2 = f"firefox ./{opath}/movie.mp4"
187188
os.system(bashcmd2)
189+
190+
191+
# Plotting for bacteria rods optimization:
192+
def plot_optimization_convergence_bayes(res, path='', add_name=''):
    """Plot and save the skopt convergence curve for an optimization result."""
    figure, _ = plt.subplots()
    plot_convergence(res)
    outfile = path + 'skopt_convergence' + add_name + '.png'
    plt.savefig(outfile, bbox_inches='tight')
    plt.close(figure)
197+
198+
199+
def plot_1D_cost_approximation_bayes(res, path='', add_name=''):
    """Plot and save the Gaussian-process approximation of the 1D cost landscape."""
    figure, _ = plt.subplots()
    _ = plot_gaussian_process(res)
    outfile = path + 'cost_approximation' + add_name + '.png'
    plt.savefig(outfile, bbox_inches='tight')
    plt.close(figure)
204+
205+
def plot_objective_projection_bayes(res, path='', add_name=''):
    """Plot and save pairwise partial-dependence projections of the objective."""
    figure, _ = plt.subplots()
    _ = plot_objective(res, n_points=10)
    outfile = path + 'objective_projection' + add_name + '.png'
    plt.savefig(outfile, bbox_inches='tight')
    plt.close(figure)

docs/requirements.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,9 @@ setuptools==80.9.0
55
myst_parser==4.0.1
66
sphinx_subfigure==0.2.4
77
bayesian-optimization==3.0.0
8+
scikit-optimize==0.10.2
9+
numpy==2.3.0
10+
scipy==1.15.3
11+
joblib==1.5.1
12+
scikit-learn==1.7.0
13+
matplotlib==3.10.3

0 commit comments

Comments
 (0)