|
6 | 6 | import torch
|
7 | 7 | from poli.objective_repository import RaspProblemFactory
|
8 | 8 |
|
9 |
| -from poli_baselines.solvers.bayesian_optimization.lambo2 import LaMBO2 |
| 9 | +from poli_baselines.solvers.simple.random_mutation import RandomMutation |
10 | 10 |
|
11 | 11 | THIS_DIR = Path(__file__).resolve().parent
|
12 | 12 | sys.path.append(str(THIS_DIR))
|
|
15 | 15 |
|
16 | 16 |
|
17 | 17 | def run_with_default_hyperparameters():
|
| 18 | + from poli_baselines.solvers.bayesian_optimization.lambo2 import LaMBO2 |
| 19 | + |
18 | 20 | RFP_PDBS_DIR = THIS_DIR / "rfp_pdbs"
|
19 | 21 | ALL_PDBS = list(RFP_PDBS_DIR.rglob("**/*.pdb"))
|
20 | 22 | problem = RaspProblemFactory().create(
|
@@ -69,6 +71,8 @@ def run_with_modified_hyperparameters():
|
69 | 71 | You can find the original configuration we use here:
|
70 | 72 | src/poli_baselines/solvers/bayesian_optimization/lambo2/hydra_configs
|
71 | 73 | """
|
| 74 | + from poli_baselines.solvers.bayesian_optimization.lambo2 import LaMBO2 |
| 75 | + |
72 | 76 | POPULATION_SIZE = 96
|
73 | 77 | MAX_EPOCHS_FOR_PRETRAINING = 4
|
74 | 78 |
|
@@ -112,6 +116,61 @@ def run_with_modified_hyperparameters():
|
112 | 116 | black_box.terminate()
|
113 | 117 |
|
114 | 118 |
|
| 119 | +def comparing_against_directed_evolution(): |
| 120 | + arr = np.load(THIS_DIR / "rasp_seed_data.npz") |
| 121 | + x0 = arr["x0"] |
| 122 | + y0 = arr["y0"] |
| 123 | + batch_size = 128 |
| 124 | + n_iterations = 32 |
| 125 | + |
| 126 | + x0_for_solver_no_padding = x0[np.argsort(y0.flatten())[::-1]][:batch_size] |
| 127 | + |
| 128 | + # Adding padding |
| 129 | + max_length = max(map(len, x0_for_solver_no_padding)) |
| 130 | + x0_for_solver_ = [[char for char in x_i] for x_i in x0_for_solver_no_padding] |
| 131 | + x0_for_solver = np.array( |
| 132 | + [list_ + ([""] * (max_length - len(list_))) for list_ in x0_for_solver_] |
| 133 | + ) |
| 134 | + |
| 135 | + y0_for_solver = y0[np.argsort(y0.flatten())[::-1]][:batch_size] |
| 136 | + |
| 137 | + RFP_PDBS_DIR = THIS_DIR / "rfp_pdbs" |
| 138 | + ALL_PDBS = list(RFP_PDBS_DIR.rglob("**/*.pdb")) |
| 139 | + problem = RaspProblemFactory().create( |
| 140 | + wildtype_pdb_path=ALL_PDBS, |
| 141 | + additive=True, |
| 142 | + chains_to_keep=[p.parent.name.split("_")[1] for p in ALL_PDBS], |
| 143 | + ) |
| 144 | + black_box = problem.black_box |
| 145 | + |
| 146 | + observer = SimpleObserver() |
| 147 | + black_box.set_observer(observer) |
| 148 | + |
| 149 | + observer.x_s.append(x0.reshape(-1, 1)) |
| 150 | + observer.y_s.append(y0) |
| 151 | + |
| 152 | + directed_evolution = RandomMutation( |
| 153 | + black_box=black_box, |
| 154 | + x0=x0_for_solver, |
| 155 | + y0=y0_for_solver, |
| 156 | + batch_size=batch_size, |
| 157 | + ) |
| 158 | + max_eval = n_iterations * batch_size |
| 159 | + directed_evolution.solve(max_iter=max_eval // batch_size, verbose=True) |
| 160 | + observer.save_history( |
| 161 | + THIS_DIR / f"directed_evolution_rasp_trace_b_{batch_size}_i_{n_iterations}.npz" |
| 162 | + ) |
| 163 | + |
| 164 | + fig, (ax1, ax2) = plt.subplots(1, 2) |
| 165 | + plot_best_y(observer, ax1) |
| 166 | + plot_best_y(observer, ax2, start_from=x0.shape[0]) |
| 167 | + ax1.axvline(x0.shape[0], color="red") |
| 168 | + plt.show() |
| 169 | + |
| 170 | + black_box.terminate() |
| 171 | + |
| 172 | + |
115 | 173 | if __name__ == "__main__":
|
116 |
| - run_with_default_hyperparameters() |
| 174 | + # run_with_default_hyperparameters() |
117 | 175 | # run_with_modified_hyperparameters()
|
| 176 | + comparing_against_directed_evolution() |
0 commit comments