Skip to content

Commit 0e15714

Browse files
authored
feat: error handle inference and sim (#80)
* feat: add error handling to `run_inference` * feat: add error handling to NCI * feat: add error handling to BLD * feat: update tests * feat: update conformer selection * feat: add error handling in dihedral scan * feat: harmonize BLD, CS, DS * feat: add simulation error handling * feat: more consistent handling of DEV and FAST configurations * fix: incorrect stability dev config name * feat: add back `get_simulation_engine()` * fix: scaling benchmark engine run * feat: add error handling to reactivity * fix: some bugs * fix: some bugs * feat: replace tmpdir with tmp_path * feat: make failed a `Result` attribute * feat: add utility functions to filter out failed benchmarks * fix: broken imports * feat: more utils * feat: update BLD ui * feat: update conformer selection ui * feat: update dihedral scan ui * feat: update folding stability ui * feat: update noncovalent interactions ui * feat: update nudged elastic band ui * feat: update reactivity ui * feat: update reference geometry stability ui * feat: update ring planarity ui * feat: update sampling ui * refactor: write -> display * feat: typos * fix: scaling and ui * fix: nci result inheriting from `BenchmarkResult` * feat: update tests + update solvent rdf ui * feat: update stability ui * feat: update tautomers ui * feat: update water rdf ui * docs: add run_simulation * chore: remove prints * chore: remove prints * fix: nci * fix: nci test with failing structures * fix: nci ui bug
1 parent a5e6781 commit 0e15714

40 files changed

+1219
-473
lines changed

docs/source/api_reference/utils/inference_and_simulation.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,6 @@ Inference and simulation helpers
1111

1212
.. autofunction:: get_simulation_engine
1313

14+
.. autofunction:: run_simulation
15+
1416
.. autoclass:: ASESimulationEngineWithCalculator

src/mlipaudit/benchmark.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,13 @@ class BenchmarkResult(BaseModel):
3434
"""A base model for all benchmark results.
3535
3636
Attributes:
37+
failed: Whether all the simulations or inferences failed
38+
and no analysis could be performed. Defaults to False.
3739
score: The final score for the benchmark between
3840
0 and 1.
3941
"""
4042

43+
failed: bool = False
4144
score: float | None = Field(ge=0, le=1, default=None)
4245

4346

src/mlipaudit/benchmarks/bond_length_distribution/bond_length_distribution.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from mlipaudit.benchmark import Benchmark, BenchmarkResult, ModelOutput
2424
from mlipaudit.run_mode import RunMode
2525
from mlipaudit.scoring import compute_benchmark_score
26-
from mlipaudit.utils import get_simulation_engine
26+
from mlipaudit.utils import run_simulation
2727
from mlipaudit.utils.stability import is_simulation_stable
2828

2929
logger = logging.getLogger("mlipaudit")
@@ -37,12 +37,13 @@
3737
"temperature_kelvin": 300.0,
3838
}
3939

40-
SIMULATION_CONFIG_FAST = {
40+
SIMULATION_CONFIG_DEV = {
4141
"num_steps": 10,
4242
"snapshot_interval": 1,
4343
"num_episodes": 1,
4444
"temperature_kelvin": 300.0,
4545
}
46+
NUM_DEV_SYSTEMS = 2
4647

4748
DEVIATION_SCORE_THRESHOLD = 0.05
4849

@@ -75,13 +76,17 @@ class MoleculeModelOutput(BaseModel):
7576
7677
Attributes:
7778
molecule_name: The name of the molecule.
78-
simulation_state: The simulation state.
79+
simulation_state: The simulation state. Defaults to None
80+
if the simulation failed.
81+
failed: Whether the simulation failed on the molecule.
82+
Defaults to False.
7983
"""
8084

8185
model_config = ConfigDict(arbitrary_types_allowed=True)
8286

8387
molecule_name: str
84-
simulation_state: SimulationState
88+
simulation_state: SimulationState | None = None
89+
failed: bool = False
8590

8691

8792
class BondLengthDistributionModelOutput(ModelOutput):
@@ -90,9 +95,11 @@ class BondLengthDistributionModelOutput(ModelOutput):
9095
9196
Attributes:
9297
molecules: A list of simulation states for every molecule.
98+
num_failed: The number of molecules for which simulation failed.
9399
"""
94100

95101
molecules: list[MoleculeModelOutput]
102+
num_failed: int = 0
96103

97104

98105
class BondLengthDistributionMoleculeResult(BaseModel):
@@ -105,8 +112,8 @@ class BondLengthDistributionMoleculeResult(BaseModel):
105112
with each frame corresponding to 1ps of simulation time.
106113
avg_deviation: The average deviation of the molecule over the
107114
whole trajectory.
108-
failed: Whether the simulation was stable. If not stable, the other
109-
attributes will be not be set.
115+
failed: Whether the simulation succeeded and was stable. If not,
116+
the other attributes will default to None. Defaults to False.
110117
"""
111118

112119
molecule_name: str
@@ -122,7 +129,9 @@ class BondLengthDistributionResult(BenchmarkResult):
122129
Attributes:
123130
molecules: The individual results for each molecule in a list.
124131
avg_deviation: The average of the average deviations for each
125-
molecule that was stable. If no stable molecules, will be None.
132+
molecule that was stable. If the benchmark failed, will be None.
133+
failed: Whether all the simulations or inferences failed
134+
and no analysis could be performed. Defaults to False.
126135
score: The final score for the benchmark between
127136
0 and 1.
128137
"""
@@ -168,30 +177,36 @@ def run_model(self) -> None:
168177
the reference structure. The simulation state is stored in the
169178
`model_output` attribute.
170179
"""
171-
molecule_outputs = []
172-
173180
if self.run_mode == RunMode.DEV:
174-
md_kwargs = SIMULATION_CONFIG_FAST
181+
md_kwargs = SIMULATION_CONFIG_DEV
175182
else:
176183
md_kwargs = SIMULATION_CONFIG
177184

185+
molecule_outputs, num_failed = [], 0
186+
178187
for pattern_name, molecule in self._bond_length_distribution_data.items():
179188
logger.info("Running MD for %s", pattern_name)
180189

181190
atoms = Atoms(
182191
symbols=molecule.atom_symbols,
183192
positions=molecule.coordinates,
184193
)
185-
md_engine = get_simulation_engine(atoms, self.force_field, **md_kwargs)
186-
md_engine.run()
194+
simulation_state = run_simulation(atoms, self.force_field, **md_kwargs)
195+
196+
if simulation_state is not None:
197+
molecule_output = MoleculeModelOutput(
198+
molecule_name=pattern_name, simulation_state=simulation_state
199+
)
200+
else:
201+
molecule_output = MoleculeModelOutput(
202+
molecule_name=pattern_name, failed=True
203+
)
204+
num_failed += 1
187205

188-
molecule_output = MoleculeModelOutput(
189-
molecule_name=pattern_name, simulation_state=md_engine.state
190-
)
191206
molecule_outputs.append(molecule_output)
192207

193208
self.model_output = BondLengthDistributionModelOutput(
194-
molecules=molecule_outputs
209+
molecules=molecule_outputs, num_failed=num_failed
195210
)
196211

197212
def analyze(self) -> BondLengthDistributionResult:
@@ -210,19 +225,22 @@ def analyze(self) -> BondLengthDistributionResult:
210225
if self.model_output is None:
211226
raise RuntimeError("Must call run_model() first.")
212227

213-
results = []
214-
num_stable = 0
215-
for molecule_output in self.model_output.molecules:
216-
trajectory = molecule_output.simulation_state.positions
228+
results: list[BondLengthDistributionMoleculeResult] = []
229+
num_succeeded = 0
217230

218-
if not is_simulation_stable(molecule_output.simulation_state):
231+
for molecule_output in self.model_output.molecules:
232+
if molecule_output.failed or not is_simulation_stable(
233+
molecule_output.simulation_state
234+
):
219235
molecule_result = BondLengthDistributionMoleculeResult(
220236
molecule_name=molecule_output.molecule_name, failed=True
221237
)
222238
results.append(molecule_result)
223239
continue
224240

225-
num_stable += 1
241+
num_succeeded += 1
242+
243+
trajectory = molecule_output.simulation_state.positions
226244

227245
pattern_indices = self._bond_length_distribution_data[
228246
molecule_output.molecule_name
@@ -247,8 +265,10 @@ def analyze(self) -> BondLengthDistributionResult:
247265
)
248266
results.append(molecule_result)
249267

250-
if num_stable == 0:
251-
return BondLengthDistributionResult(molecules=results, score=0.0)
268+
if num_succeeded == 0:
269+
return BondLengthDistributionResult(
270+
molecules=results, failed=True, score=0.0
271+
)
252272

253273
avg_deviation = statistics.mean(
254274
r.avg_deviation for r in results if r.avg_deviation is not None
@@ -275,6 +295,6 @@ def _bond_length_distribution_data(self) -> dict[str, Molecule]:
275295
dataset = Molecules.validate_json(f.read())
276296

277297
if self.run_mode == RunMode.DEV:
278-
dataset = dict(list(dataset.items())[:2])
298+
dataset = dict(list(dataset.items())[:NUM_DEV_SYSTEMS])
279299

280300
return dataset

src/mlipaudit/benchmarks/conformer_selection/conformer_selection.py

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@
3030
logger = logging.getLogger("mlipaudit")
3131

3232
WIGGLE_DATASET_FILENAME = "wiggle150_dataset.json"
33+
NUM_DEV_SYSTEMS = 1
3334

3435
MAE_SCORE_THRESHOLD = 0.5
3536
RMSE_SCORE_THRESHOLD = 1.5
3637

3738

3839
class ConformerSelectionMoleculeResult(BaseModel):
3940
"""Results object for small molecule conformer selection benchmark for a single
40-
molecule.
41+
molecule. Will have attributes set to None if the inference failed.
4142
4243
Attributes:
4344
molecule_name: The molecule's name.
@@ -51,31 +52,37 @@ class ConformerSelectionMoleculeResult(BaseModel):
5152
and reference energy profiles.
5253
predicted_energy_profile: The predicted energy profile for each conformer.
5354
reference_energy_profile: The reference energy profiles for each conformer.
55+
failed: Whether the inference failed on the molecule.
5456
"""
5557

5658
molecule_name: str
57-
mae: NonNegativeFloat
58-
rmse: NonNegativeFloat
59-
spearman_correlation: float = Field(ge=-1.0, le=1.0)
60-
spearman_p_value: float = Field(ge=0.0, le=1.0)
61-
predicted_energy_profile: list[float]
62-
reference_energy_profile: list[float]
59+
mae: NonNegativeFloat | None = None
60+
rmse: NonNegativeFloat | None = None
61+
spearman_correlation: float | None = Field(ge=-1.0, le=1.0, default=None)
62+
spearman_p_value: float | None = Field(ge=0.0, le=1.0, default=None)
63+
predicted_energy_profile: list[float] | None = None
64+
reference_energy_profile: list[float] | None = None
65+
failed: bool = False
6366

6467

6568
class ConformerSelectionResult(BenchmarkResult):
6669
"""Results object for small molecule conformer selection benchmark.
6770
6871
Attributes:
6972
molecules: The individual results for each molecule in a list.
70-
avg_mae: The MAE values for all molecules averaged.
71-
avg_rmse: The RMSE values for all molecules averaged.
73+
avg_mae: The MAE values for all molecules that didn't fail averaged.
74+
Is None in the case all the inferences failed.
75+
avg_rmse: The RMSE values for all molecules that didn't fail averaged.
76+
Is None in the case all the inferences failed.
77+
failed: Whether all the simulations or inferences failed
78+
and no analysis could be performed. Defaults to False.
7279
score: The final score for the benchmark between
7380
0 and 1.
7481
"""
7582

7683
molecules: list[ConformerSelectionMoleculeResult]
77-
avg_mae: NonNegativeFloat
78-
avg_rmse: NonNegativeFloat
84+
avg_mae: NonNegativeFloat | None = None
85+
avg_rmse: NonNegativeFloat | None = None
7986

8087

8188
class ConformerSelectionMoleculeModelOutput(BaseModel):
@@ -84,20 +91,25 @@ class ConformerSelectionMoleculeModelOutput(BaseModel):
8491
Attributes:
8592
molecule_name: The molecule's name.
8693
predicted_energy_profile: The predicted energy profile for the conformers.
94+
Is None if the inference failed on the molecule.
95+
failed: Whether the inference failed on the molecule.
8796
"""
8897

8998
molecule_name: str
90-
predicted_energy_profile: list[float]
99+
predicted_energy_profile: list[float] | None = None
100+
failed: bool = False
91101

92102

93103
class ConformerSelectionModelOutput(ModelOutput):
94104
"""Stores model outputs for the conformer selection benchmark.
95105
96106
Attributes:
97107
molecules: Results for each molecule.
108+
num_failed: The number of molecules on which inference failed.
98109
"""
99110

100111
molecules: list[ConformerSelectionMoleculeModelOutput]
112+
num_failed: int = 0
101113

102114

103115
class Conformer(BaseModel):
@@ -160,7 +172,7 @@ def run_model(self) -> None:
160172
The calculation is performed as a batched inference using the MLIP force field
161173
directly. The energy profile is stored in the `model_output` attribute.
162174
"""
163-
molecule_outputs = []
175+
molecule_outputs, num_failed = [], 0
164176
for structure in self._wiggle150_data:
165177
logger.info("Running energy calculations for %s", structure.molecule_name)
166178

@@ -178,17 +190,23 @@ def run_model(self) -> None:
178190
batch_size=16,
179191
)
180192

181-
energy_profile_list: list[float] = [
182-
prediction.energy for prediction in predictions
183-
]
193+
if None in predictions:
194+
model_output = ConformerSelectionMoleculeModelOutput(
195+
molecule_name=structure.molecule_name, failed=True
196+
)
197+
num_failed += 1
184198

185-
model_output = ConformerSelectionMoleculeModelOutput(
186-
molecule_name=structure.molecule_name,
187-
predicted_energy_profile=energy_profile_list,
188-
)
199+
else:
200+
energy_profile_list = [prediction.energy for prediction in predictions] # type: ignore
201+
model_output = ConformerSelectionMoleculeModelOutput(
202+
molecule_name=structure.molecule_name,
203+
predicted_energy_profile=energy_profile_list,
204+
)
189205
molecule_outputs.append(model_output)
190206

191-
self.model_output = ConformerSelectionModelOutput(molecules=molecule_outputs)
207+
self.model_output = ConformerSelectionModelOutput(
208+
molecules=molecule_outputs, num_failed=num_failed
209+
)
192210

193211
def analyze(self) -> ConformerSelectionResult:
194212
"""Calculates the MAE, RMSE and Spearman correlation.
@@ -210,12 +228,21 @@ def analyze(self) -> ConformerSelectionResult:
210228
conformer.molecule_name: np.array(conformer.dft_energy_profile)
211229
for conformer in self._wiggle150_data
212230
}
213-
214231
results = []
232+
215233
for molecule in self.model_output.molecules:
216234
molecule_name = molecule.molecule_name
217-
energy_profile = molecule.predicted_energy_profile
218-
energy_profile = np.array(energy_profile)
235+
236+
if molecule.failed:
237+
results.append(
238+
ConformerSelectionMoleculeResult(
239+
molecule_name=molecule_name, failed=True
240+
)
241+
)
242+
continue
243+
244+
energy_profile = np.array(molecule.predicted_energy_profile)
245+
219246
ref_energy_profile = np.array(reference_energy_profiles[molecule_name])
220247

221248
min_ref_energy = np.min(ref_energy_profile)
@@ -251,8 +278,11 @@ def analyze(self) -> ConformerSelectionResult:
251278

252279
results.append(molecule_result)
253280

254-
avg_mae = statistics.mean(r.mae for r in results)
255-
avg_rmse = statistics.mean(r.rmse for r in results)
281+
if self.model_output.num_failed == len(self.model_output.molecules):
282+
return ConformerSelectionResult(molecules=results, failed=True, score=0.0)
283+
284+
avg_mae = statistics.mean(r.mae for r in results if r.mae is not None)
285+
avg_rmse = statistics.mean(r.rmse for r in results if r.rmse is not None)
256286

257287
score = compute_benchmark_score(
258288
[[r.mae for r in results], [r.rmse for r in results]],
@@ -276,6 +306,6 @@ def _wiggle150_data(self) -> list[Conformer]:
276306
wiggle150_data = Conformers.validate_json(f.read())
277307

278308
if self.run_mode == RunMode.DEV:
279-
wiggle150_data = wiggle150_data[:1]
309+
wiggle150_data = wiggle150_data[:NUM_DEV_SYSTEMS]
280310

281311
return wiggle150_data

0 commit comments

Comments
 (0)