Skip to content

Commit b31d415

Browse files
authored
Adds example: lambo2 on foldx (#52)
* Updates the examples folder with recent changes * Adds a basic example on how to run lambo2 on foldx stability * Exposes number of mutations * Starts prototyping imposing a wt for candidate points * Allows for imposing a batch of custom candidate points * Implements a saving mechanism from the history * Adds maximum sequence length to the generic training configs * Removes unused imports * Adds a readme to the example * Removes a comment in lambo2 solver script * Removes some TODO comments
1 parent dcbe4ed commit b31d415

File tree

13 files changed

+1239
-11
lines changed

13 files changed

+1239
-11
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ examples/09_replicating_nsga_ii_of_lambo_by_hand/repaired_pdbs
2323
examples/09_replicating_nsga_ii_of_lambo_by_hand/pHs.json
2424
examples/09_replicating_nsga_ii_of_lambo_by_hand/history/
2525
examples/09_replicating_nsga_ii_of_lambo_by_hand/history.json
26+
examples/06_running_lambo2_on_rasp/*.npz
27+
examples/07_running_lambo2_on_foldx/tmp/
28+
examples/07_running_lambo2_on_foldx/*.csv
29+
examples/07_running_lambo2_on_foldx/lambo2_trace.npz
2630
examples/ignore*
2731

2832
# BAxUS related stuff

examples/06_running_lambo2_on_rasp/run.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
from poli.objective_repository import RaspProblemFactory
88

9-
from poli_baselines.solvers.bayesian_optimization.lambo2 import LaMBO2
9+
from poli_baselines.solvers.simple.random_mutation import RandomMutation
1010

1111
THIS_DIR = Path(__file__).resolve().parent
1212
sys.path.append(str(THIS_DIR))
@@ -15,6 +15,8 @@
1515

1616

1717
def run_with_default_hyperparameters():
18+
from poli_baselines.solvers.bayesian_optimization.lambo2 import LaMBO2
19+
1820
RFP_PDBS_DIR = THIS_DIR / "rfp_pdbs"
1921
ALL_PDBS = list(RFP_PDBS_DIR.rglob("**/*.pdb"))
2022
problem = RaspProblemFactory().create(
@@ -69,6 +71,8 @@ def run_with_modified_hyperparameters():
6971
You can find the original configuration we use here:
7072
src/poli_baselines/solvers/bayesian_optimization/lambo2/hydra_configs
7173
"""
74+
from poli_baselines.solvers.bayesian_optimization.lambo2 import LaMBO2
75+
7276
POPULATION_SIZE = 96
7377
MAX_EPOCHS_FOR_PRETRAINING = 4
7478

@@ -112,6 +116,61 @@ def run_with_modified_hyperparameters():
112116
black_box.terminate()
113117

114118

119+
def comparing_against_directed_evolution():
    """Run a random-mutation baseline on the RaSP problem and plot its progress.

    Loads the seed data stored in ``rasp_seed_data.npz`` next to this script,
    hands the ``batch_size`` highest-scoring seed sequences to
    ``RandomMutation``, runs it for ``n_iterations`` iterations, saves the
    observer history to an ``.npz`` file in this directory, and plots the
    best value observed so far (with and without the seed data).

    The black box is always terminated, even if solving, saving or plotting
    raises.
    """
    # Close the archive as soon as both arrays have been read.
    with np.load(THIS_DIR / "rasp_seed_data.npz") as arr:
        x0 = arr["x0"]
        y0 = arr["y0"]
    batch_size = 128
    n_iterations = 32

    # Rank the seed points once (highest y0 first) and reuse the same
    # indices for inputs and outputs so they stay aligned.
    top_indices = np.argsort(y0.flatten())[::-1][:batch_size]
    x0_for_solver_no_padding = x0[top_indices]
    y0_for_solver = y0[top_indices]

    # Adding padding: split every sequence into characters and right-pad
    # with empty strings up to the longest selected sequence.
    # NOTE(review): assumes x0 holds one sequence string per entry -- confirm.
    max_length = max(map(len, x0_for_solver_no_padding))
    x0_for_solver = np.array(
        [
            list(x_i) + [""] * (max_length - len(x_i))
            for x_i in x0_for_solver_no_padding
        ]
    )

    RFP_PDBS_DIR = THIS_DIR / "rfp_pdbs"
    ALL_PDBS = list(RFP_PDBS_DIR.rglob("**/*.pdb"))
    problem = RaspProblemFactory().create(
        wildtype_pdb_path=ALL_PDBS,
        additive=True,
        chains_to_keep=[p.parent.name.split("_")[1] for p in ALL_PDBS],
    )
    black_box = problem.black_box

    try:
        observer = SimpleObserver()
        black_box.set_observer(observer)

        # Record the whole seed set first so the saved trace and the plots
        # include it ahead of the solver's own evaluations.
        observer.x_s.append(x0.reshape(-1, 1))
        observer.y_s.append(y0)

        directed_evolution = RandomMutation(
            black_box=black_box,
            x0=x0_for_solver,
            y0=y0_for_solver,
            batch_size=batch_size,
        )
        directed_evolution.solve(max_iter=n_iterations, verbose=True)
        observer.save_history(
            THIS_DIR
            / f"directed_evolution_rasp_trace_b_{batch_size}_i_{n_iterations}.npz"
        )

        _, (ax1, ax2) = plt.subplots(1, 2)
        plot_best_y(observer, ax1)
        plot_best_y(observer, ax2, start_from=x0.shape[0])
        # The red line marks where the seed data ends.
        ax1.axvline(x0.shape[0], color="red")
        plt.show()
    finally:
        black_box.terminate()
172+
115173
if __name__ == "__main__":
116-
run_with_default_hyperparameters()
174+
# run_with_default_hyperparameters()
117175
# run_with_modified_hyperparameters()
176+
comparing_against_directed_evolution()

examples/06_running_lambo2_on_rasp/simple_observer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None:
1818
self.x_s.append(x)
1919
self.y_s.append(y)
2020

21+
def save_history(self, path: str) -> None:
    """Write every recorded input and output to an ``.npz`` archive.

    Each row of every recorded ``x`` batch is joined into a single string.
    The joined strings are stored under the key ``x_s`` and the vertically
    stacked outputs under the key ``y_s``.

    Parameters
    ----------
    path : str
        Destination passed straight to ``np.savez``.
    """
    joined_batches = [
        np.array(["".join(tokens) for tokens in batch]) for batch in self.x_s
    ]
    np.savez(
        path,
        x_s=np.concatenate(joined_batches),
        y_s=np.vstack(self.y_s),
    )
28+
2129

2230
def plot_best_y(obs: SimpleObserver, ax: plt.Axes, start_from: int = 0):
2331
best_y = np.maximum.accumulate(np.vstack(obs.y_s).flatten())
Binary file not shown.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
This folder includes an example in which we optimize the thermal stability of two PDBs (DNJA1 and RFAH), measured using `foldx`, using `LaMBO2`.
2+
3+
As a prerequisite, [we encourage you to set up `poli` for `foldx`](https://machinelearninglifescience.github.io/poli-docs/using_poli/objective_repository/foldx_stability.html).
4+
5+
We recommend running it inside the environment of `LaMBO2`, which you can find inside the `solvers` folder.
6+
7+
```bash
8+
# From the root of the poli-baselines directory
9+
pip install -e .[lambo2]
10+
python examples/07_running_lambo2_on_foldx/run.py
11+
```

0 commit comments

Comments
 (0)