Skip to content

Commit 1d9aefe

Browse files
lwalew and chrbrunk
authored
feat: allow reusing model output from folding stability for sampling (#93)
* feat: allow reusing model output for sampling * feat: attempt removing circular import * feat: common `run_model()` for FS and S * feat: update analysis * feat: update biomolecules dev configuration * feat: temporarily shorten sampling MD length * feat: decrease num steps from 1ns->250ps for other models than 1922 * feat: increase num steps to 1ns * feat: decrease num steps from 1ns->250ps for other models than 1922 Use 25 episodes to have valid configuration * feat: increase num steps to 1ns * feat: decrease num steps from 1ns->250ps for other models than 1922 * feat: update num steps for sampling * feat: remove `run_biomolecules()` * feat: reuse outputs in cli * feat: check that we have the correct structures * fix: single backbone dihedrals * fix: subset check * feat: improve first draft on reusable model outputs * fix: sampling error (#97) * fix: get dihedrals per unique name * chore: remove comments * fix: sampling tests * fix: remove ILE * docs: update tutorial for adding new benchmark in docs * feat: better asserts for benchmarks which reuse model outputs * test: update unit tests with patch for assertion of model output transfer --------- Co-authored-by: Christoph Brunken <c.brunken@instadeep.com>
1 parent 255f838 commit 1d9aefe

File tree

21 files changed

+3594
-4184
lines changed

21 files changed

+3594
-4184
lines changed

docs/source/api_reference/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Benchmark implementations
3535
small_molecules/radial_distribution
3636
small_molecules/solvent_radial_distribution
3737
small_molecules/reactivity
38+
small_molecules/nudged_elastic_band
3839
biomolecules/folding_stability
3940
biomolecules/sampling
4041
general/stability

docs/source/tutorials/new_benchmark/index.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,17 @@ Here is an example of a very minimal new benchmark implementation:
127127
The data loading as a cached property is only recommended if the loaded data
128128
is needed in both the `run_model()` and the `analyze()` functions.
129129

130-
Note that the functions `_compute_energies_blackbox` and `_analyze_blackbox`` are
130+
Note that the functions `_compute_energies_blackbox` and `_analyze_blackbox` are
131131
placeholders for the actual implementations.
132132

133+
Another class attribute that can be specified optionally is `reusable_output_id`,
134+
which is `None` by default. It can be used to signal that two benchmarks use the exact
135+
same `run_model()` method and the exact same signature for the model output class.
136+
This ID should be of type tuple with the names of the benchmarks in it, see the
137+
benchmarks `Sampling` and `FoldingStability` for an example of this. See the source code
138+
of the main benchmarking script for how it reuses the model output of one for the other
139+
benchmark without rerunning any simulation or inference.
140+
133141
**Furthermore, you need to add an import for your benchmark to the**
134142
`src/mlipaudit/benchmarks/__init__.py` **file such that the benchmark can be**
135143
**automatically picked up by the CLI tool.**

src/mlipaudit/benchmark.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ class Benchmark(ABC):
7676
if there are some element types that the model cannot handle. If False,
7777
the benchmark must have its own custom logic to handle missing element
7878
types. Defaults to True.
79+
reusable_output_id: An optional ID that references other benchmarks with
80+
identical input systems and `ModelOutput` signatures (in form of a tuple).
81+
If present, a user or the CLI can make use of this information to reuse
82+
cached model outputs from another benchmark carrying the same ID instead of
83+
rerunning simulations or inference.
7984
"""
8085

8186
name: str = ""
@@ -86,6 +91,8 @@ class Benchmark(ABC):
8691
required_elements: set[str] | None = None
8792
skip_if_elements_missing: bool = True
8893

94+
reusable_output_id: tuple[str, ...] | None = None
95+
8996
def __init__(
9097
self,
9198
force_field: ForceField | ASECalculator,

src/mlipaudit/benchmarks/folding_stability/folding_stability.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
create_mdtraj_trajectory_from_simulation_state,
3434
get_simulation_engine,
3535
)
36+
from mlipaudit.utils.simulation import REUSABLE_BIOMOLECULES_OUTPUTS_ID
3637
from mlipaudit.utils.stability import is_simulation_stable
3738

3839
logger = logging.getLogger("mlipaudit")
@@ -50,16 +51,16 @@
5051
}
5152

5253
SIMULATION_CONFIG = {
53-
"num_steps": 100_000,
54-
"snapshot_interval": 100,
55-
"num_episodes": 100,
54+
"num_steps": 250_000,
55+
"snapshot_interval": 10_000,
56+
"num_episodes": 25,
5657
"temperature_kelvin": 300.0,
5758
}
5859

5960
SIMULATION_CONFIG_DEV = {
60-
"num_steps": 10,
61+
"num_steps": 5,
6162
"snapshot_interval": 1,
62-
"num_episodes": 10,
63+
"num_episodes": 1,
6364
"temperature_kelvin": 300.0,
6465
}
6566

@@ -168,6 +169,11 @@ class FoldingStabilityBenchmark(Benchmark):
168169
if there are some atomic element types that the model cannot handle. If
169170
False, the benchmark must have its own custom logic to handle missing atomic
170171
element types. For this benchmark, the attribute is set to True.
172+
reusable_output_id: An optional ID that references other benchmarks with
173+
identical input systems and `ModelOutput` signatures (in form of a tuple).
174+
If present, a user or the CLI can make use of this information to reuse
175+
cached model outputs from another benchmark carrying the same ID instead of
176+
rerunning simulations or inference.
171177
"""
172178

173179
name = "folding_stability"
@@ -177,16 +183,13 @@ class FoldingStabilityBenchmark(Benchmark):
177183

178184
required_elements = {"H", "N", "O", "S", "C"}
179185

186+
reusable_output_id = REUSABLE_BIOMOLECULES_OUTPUTS_ID
187+
180188
def run_model(self) -> None:
181189
"""Run an MD simulation for each biosystem.
182190
183191
The simulation results are stored in the `model_output` attribute.
184192
"""
185-
self.model_output = FoldingStabilityModelOutput(
186-
structure_names=[],
187-
simulation_states=[],
188-
)
189-
190193
if self.run_mode == RunMode.DEV:
191194
structure_names = STRUCTURE_NAMES[:1]
192195
elif self.run_mode == RunMode.FAST:
@@ -199,10 +202,17 @@ def run_model(self) -> None:
199202
else:
200203
md_kwargs = SIMULATION_CONFIG
201204

205+
self.model_output = FoldingStabilityModelOutput(
206+
structure_names=[],
207+
simulation_states=[],
208+
)
209+
202210
for structure_name in structure_names:
203211
logger.info("Running MD for %s", structure_name)
204212
xyz_filename = structure_name + ".xyz"
205-
atoms = ase_read(self.data_input_dir / self.name / xyz_filename)
213+
atoms = ase_read(
214+
self.data_input_dir / self.name / "starting_structures" / xyz_filename
215+
)
206216

207217
md_engine = get_simulation_engine(
208218
atoms, self.force_field, box=BOX_SIZES[structure_name], **md_kwargs
@@ -230,6 +240,8 @@ def analyze(self) -> FoldingStabilityResult:
230240
if self.model_output is None:
231241
raise RuntimeError("Must call run_model() first.")
232242

243+
self._assert_structure_names_in_model_output()
244+
233245
molecule_results = []
234246
num_stable = 0
235247

@@ -246,13 +258,15 @@ def analyze(self) -> FoldingStabilityResult:
246258
continue
247259

248260
num_stable += 1
249-
250-
topology_filename = structure_name + ".pdb"
251-
ref_filename = structure_name + "_ref.pdb"
261+
box_size = BOX_SIZES[structure_name]
252262

253263
mdtraj_traj_solv = create_mdtraj_trajectory_from_simulation_state(
254264
simulation_state,
255-
topology_path=self.data_input_dir / self.name / topology_filename,
265+
topology_path=self.data_input_dir
266+
/ self.name
267+
/ "pdb_reference_structures"
268+
/ f"{structure_name}.pdb",
269+
cell_lengths=box_size, # type: ignore
256270
)
257271
ase_traj_solv = create_ase_trajectory_from_simulation_state(
258272
simulation_state
@@ -271,14 +285,20 @@ def analyze(self) -> FoldingStabilityResult:
271285
# 2. Match in secondary structure (from DSSP)
272286
match_secondary_structure = get_match_secondary_structure(
273287
mdtraj_traj,
274-
ref_path=self.data_input_dir / self.name / ref_filename,
288+
ref_path=self.data_input_dir
289+
/ self.name
290+
/ "pdb_reference_structures"
291+
/ f"{structure_name}_ref.pdb",
275292
simplified=False,
276293
)
277294

278295
# 3. TM-score and RMSD
279296
tm_scores, rmsd_values = compute_tm_scores_and_rmsd_values(
280297
mdtraj_traj,
281-
self.data_input_dir / self.name / ref_filename,
298+
self.data_input_dir
299+
/ self.name
300+
/ "pdb_reference_structures"
301+
/ f"{structure_name}_ref.pdb",
282302
)
283303

284304
initial_rg = rg_values[0]
@@ -333,3 +353,14 @@ def analyze(self) -> FoldingStabilityResult:
333353
),
334354
score=score,
335355
)
356+
357+
def _assert_structure_names_in_model_output(self) -> None:
358+
"""Asserts whether model output structure names are fine as potentially they
359+
have been transferred from a different benchmark.
360+
"""
361+
assert set(self.model_output.structure_names).issubset(STRUCTURE_NAMES) # type: ignore
362+
assert len(self.model_output.structure_names) == ( # type: ignore
363+
1
364+
if self.run_mode == RunMode.DEV
365+
else (2 if self.run_mode == RunMode.FAST else len(STRUCTURE_NAMES))
366+
)

src/mlipaudit/benchmarks/sampling/helpers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,13 @@ def get_all_dihedrals_from_trajectory(
188188
dihedrals[residue] = {}
189189
dihedrals[residue][dihedral_name] = angles_deg[:, i]
190190

191-
return dihedrals
191+
# Drop residues which contain only one of the backbone dihedrals phi and psi
192+
filtered_dihedrals = {}
193+
for residue, dihedrals in dihedrals.items():
194+
if not ("phi" in dihedrals) ^ ("psi" in dihedrals):
195+
filtered_dihedrals[residue] = dihedrals
196+
197+
return filtered_dihedrals
192198

193199

194200
def identify_outlier_data_points(

src/mlipaudit/benchmarks/sampling/sampling.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -35,34 +35,35 @@
3535
create_mdtraj_trajectory_from_simulation_state,
3636
get_simulation_engine,
3737
)
38+
from mlipaudit.utils.simulation import REUSABLE_BIOMOLECULES_OUTPUTS_ID
3839
from mlipaudit.utils.stability import is_simulation_stable
3940

4041
logger = logging.getLogger("mlipaudit")
4142

4243
STRUCTURE_NAMES = [
43-
"thr_ile_solv",
44-
"asn_asp_solv",
45-
"tyr_trp_solv",
44+
"chignolin_1uao_xray",
45+
"trp_cage_2jof_xray",
46+
"orexin_beta_1cq0_nmr",
4647
]
4748

48-
CUBIC_BOX_SIZES = {
49-
"thr_ile_solv": 21.802,
50-
"asn_asp_solv": 21.806,
51-
"tyr_trp_solv": 24.012,
49+
BOX_SIZES = {
50+
"chignolin_1uao_xray": [23.98, 22.45, 20.68],
51+
"trp_cage_2jof_xray": [29.33, 29.74, 23.59],
52+
"orexin_beta_1cq0_nmr": [40.30, 29.56, 33.97],
5253
}
5354

5455
SIMULATION_CONFIG = {
55-
"num_steps": 150_000,
56-
"snapshot_interval": 1000,
57-
"num_episodes": 150,
58-
"temperature_kelvin": 350.0,
56+
"num_steps": 250_000,
57+
"snapshot_interval": 10_000,
58+
"num_episodes": 25,
59+
"temperature_kelvin": 300.0,
5960
}
6061

61-
SIMULATION_CONFIG_FAST = {
62-
"num_steps": 1,
62+
SIMULATION_CONFIG_DEV = {
63+
"num_steps": 5,
6364
"snapshot_interval": 1,
6465
"num_episodes": 1,
65-
"temperature_kelvin": 350.0,
66+
"temperature_kelvin": 300.0,
6667
}
6768

6869
RESNAME_TO_BACKBONE_RESIDUE_TYPE = {
@@ -260,6 +261,11 @@ class SamplingBenchmark(Benchmark):
260261
if there are some atomic element types that the model cannot handle. If
261262
False, the benchmark must have its own custom logic to handle missing atomic
262263
element types. For this benchmark, the attribute is set to True.
264+
reusable_output_id: An optional ID that references other benchmarks with
265+
identical input systems and `ModelOutput` signatures (in form of a tuple).
266+
If present, a user or the CLI can make use of this information to reuse
267+
cached model outputs from another benchmark carrying the same ID instead of
268+
rerunning simulations or inference.
263269
"""
264270

265271
name = "sampling"
@@ -269,36 +275,37 @@ class SamplingBenchmark(Benchmark):
269275

270276
required_elements = {"N", "H", "O", "S", "C"}
271277

278+
reusable_output_id = REUSABLE_BIOMOLECULES_OUTPUTS_ID
279+
272280
def run_model(self) -> None:
273281
"""Run an MD simulation for each system."""
274-
self.model_output = SamplingModelOutput(
275-
structure_names=[],
276-
simulation_states=[],
277-
)
278-
279282
if self.run_mode == RunMode.DEV:
280-
md_config_dict = SIMULATION_CONFIG_FAST
281-
structure_names = ["thr_ile_solv"]
283+
structure_names = STRUCTURE_NAMES[:1]
282284
elif self.run_mode == RunMode.FAST:
283-
md_config_dict = SIMULATION_CONFIG
284-
structure_names = ["thr_ile_solv", "asn_asp_solv"]
285+
structure_names = STRUCTURE_NAMES[:2]
285286
else:
286-
md_config_dict = SIMULATION_CONFIG
287287
structure_names = STRUCTURE_NAMES
288288

289+
if self.run_mode == RunMode.DEV:
290+
md_kwargs = SIMULATION_CONFIG_DEV
291+
else:
292+
md_kwargs = SIMULATION_CONFIG
293+
294+
self.model_output = SamplingModelOutput(
295+
structure_names=[],
296+
simulation_states=[],
297+
)
298+
289299
for structure_name in structure_names:
290300
logger.info("Running MD for %s", structure_name)
291301
xyz_filename = structure_name + ".xyz"
292-
box_size = CUBIC_BOX_SIZES[structure_name]
293-
md_kwargs = dict(
294-
box=box_size,
295-
**md_config_dict,
296-
)
297302
atoms = ase_read(
298303
self.data_input_dir / self.name / "starting_structures" / xyz_filename
299304
)
300305

301-
md_engine = get_simulation_engine(atoms, self.force_field, **md_kwargs)
306+
md_engine = get_simulation_engine(
307+
atoms, self.force_field, box=BOX_SIZES[structure_name], **md_kwargs
308+
)
302309
md_engine.run()
303310

304311
final_state = md_engine.state
@@ -317,6 +324,8 @@ def analyze(self) -> SamplingResult:
317324
if self.model_output is None:
318325
raise RuntimeError("Must call run_model() first.")
319326

327+
self._assert_structure_names_in_model_output()
328+
320329
systems = []
321330
skipped_systems = []
322331

@@ -359,7 +368,7 @@ def analyze(self) -> SamplingResult:
359368
continue
360369

361370
num_stable += 1
362-
box_size = CUBIC_BOX_SIZES[structure_name]
371+
box_size = BOX_SIZES[structure_name]
363372

364373
trajectory = create_mdtraj_trajectory_from_simulation_state(
365374
simulation_state,
@@ -369,7 +378,7 @@ def analyze(self) -> SamplingResult:
369378
/ "pdb_reference_structures"
370379
/ f"{structure_name}.pdb"
371380
),
372-
cell_lengths=(box_size, box_size, box_size),
381+
cell_lengths=box_size, # type: ignore
373382
)
374383

375384
dihedrals_data = get_all_dihedrals_from_trajectory(trajectory)
@@ -704,17 +713,24 @@ def _get_sampled_distributions(
704713

705714
unique_residue_names = set([residue.name for residue in dihedrals_data.keys()])
706715

716+
dihedrals_per_unique_name: dict[str, dict[str, np.ndarray]] = {}
717+
for residue, dihedrals in dihedrals_data.items():
718+
if residue.name not in dihedrals_per_unique_name:
719+
dihedrals_per_unique_name[residue.name] = defaultdict(list)
720+
for dihedral_type, angle_list in dihedrals.items():
721+
dihedrals_per_unique_name[residue.name][dihedral_type].extend(
722+
angle_list
723+
)
724+
707725
for residue_name in unique_residue_names:
708726
if not backbone:
709727
dihedral_keys = self._get_allowed_sidechain_dihedral_keys(residue_name)
710728
if len(dihedral_keys) == 0:
711729
continue
712730

713731
sampled_distributions[residue_name] = np.column_stack([
714-
dihedrals_data[residue][dihedral_key]
715-
for residue in dihedrals_data.keys()
732+
dihedrals_per_unique_name[residue_name][dihedral_key]
716733
for dihedral_key in dihedral_keys
717-
if residue.name == residue_name
718734
])
719735

720736
return sampled_distributions
@@ -779,3 +795,14 @@ def _average_over_residues(
779795
The average metrics.
780796
"""
781797
return np.mean(list(metrics_per_residue.values()))
798+
799+
def _assert_structure_names_in_model_output(self) -> None:
800+
"""Asserts whether model output structure names are fine as potentially they
801+
have been transferred from a different benchmark.
802+
"""
803+
assert set(self.model_output.structure_names).issubset(STRUCTURE_NAMES) # type: ignore
804+
assert len(self.model_output.structure_names) == ( # type: ignore
805+
1
806+
if self.run_mode == RunMode.DEV
807+
else (2 if self.run_mode == RunMode.FAST else len(STRUCTURE_NAMES))
808+
)

0 commit comments

Comments
 (0)