Commit f1f50ca

lwalew, leonwehrhan, and chrbrunk authored
feat: add scoring and leaderboard (#44)
* chore: add changes
* docs: revision (#43)
* chore: delete spurious file
* feat: remove sync-hf-space.yaml.yml
* feat: don't lfs track .png
* chore: move .png from lfs to regular git
* chore: remove git lfs .gitattributes
* feat: add score loading utility functions
* feat: update app and leaderboard
* feat: remove parse_scores
* docs: add scores to api
* feat: better handling of `all` in argparse
* refactor: rename `remote` -> `is_public`
* refactor: delete Dockerfile
* feat: inverse ordering of args for `__hf`
* feat: remove save model outputs option
* feat: remove comment about scores
* refactor: `Overall score` -> `overall_score`
* feat: prettier printing benchmark names
* feat: update the model and benchmark names
* feat: cache slow resources in UI (#46)
* feat: cache images in conformer selection ui page
* feat: cache image in dihedral scan as well
* fix: ppi 600 -> 300; bugfix with caching data
* feat: more shared utils for ui pages
* refactor: update_benchmark_names
* fix: bugs with Overall score
* fix: bugs with Overall score
* feat: remove ugly colours
* feat: add unit test, small other refactors, update docs
* ci: remove jax-md from deps
* ci: get ci running again
* ci: get ci running again
* fix: ui unit test should work now
* test: increase timeout for ui page test

---------

Co-authored-by: Leon Wehrhan <[email protected]>
Co-authored-by: Christoph Brunken <[email protected]>
1 parent e4c6ddf commit f1f50ca

47 files changed: 880 additions and 317 deletions

.github/workflows/deploy_docs.yaml

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ jobs:
 
       - name: Build documentation
         run: |
+          uv sync --group dev --group jax_md
           uv run sphinx-build -b html docs/source/ _build/
 
       - name: Deploy to GitHub Pages

.github/workflows/tests_and_linters.yaml

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ jobs:
 
       - name: Run tests 🧪
         run: |
+          uv sync --group dev --group jax_md
           uv run pytest --verbose --cov-report xml:coverage.xml \
             --cov-report term-missing \
             --junitxml=pytest.xml \

docs/source/api_reference/index.rst

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ Base classes and utilities
    benchmark
    io
    run_mode
+   scoring
    utils/trajectory_helpers
 
 Benchmark implementations

docs/source/api_reference/io.rst

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,12 @@ I/O of model outputs and benchmark results
 
 .. autofunction:: load_benchmark_results_from_disk
 
+.. autofunction:: write_scores_to_disk
+
+.. autofunction:: load_score_from_disk
+
+.. autofunction:: load_scores_from_disk
+
 .. autofunction:: write_model_output_to_disk
 
 .. autofunction:: load_model_output_from_disk
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+.. _scoring:
+
+.. module:: mlipaudit.scoring
+
+Scoring
+=======
+
+.. autofunction:: compute_metric_score
+
+.. autofunction:: compute_benchmark_score
-109 KB
Binary file not shown.

docs/source/tutorials/cli/index.rst

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ The tool has the following command line options:
   list of benchmark names (e.g., ``dihedral_scan``, ``ring_planarity``) or ``all`` to
   run all available benchmarks which is also the default which means that if this flag
   is not used, all benchmarks will be run.
-* ``--run-mode``: *Optional* setting that allows to run faster versions of the
+* ``-rm / --run-mode``: *Optional* setting that allows to run faster versions of the
   benchmark suite. The default option ``standard`` which runs the entire suite.
   The option ``fast`` runs a slightly faster version for some of the very long-running
   benchmarks. The option ``dev`` runs a very minimal version of each benchmark for

pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -16,7 +16,6 @@ dependencies = [
     "vl-convert-python>=1.8.0",
     "mdtraj>=1.10.3",
     "tmtools==0.2.0",
-    "jax-md",
 ]
 
 [project.scripts]
@@ -45,6 +44,9 @@ gpu = [
     "jax[cuda12]==0.4.33",
     "jaxlib==0.4.33"
 ]
+jax_md = [
+    "jax-md",
+]
 
 [tool.coverage.run]
 omit = [
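
With `jax-md` moved out of the core dependency list into its own `jax_md` group, an environment synced without `--group jax_md` will not contain the package. The following is a minimal sketch of how calling code could detect that case before importing. It assumes the import name is `jax_md`; the helpers `is_installed` and `require_optional` are hypothetical and are not part of mlipaudit.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located without importing it."""
    return importlib.util.find_spec(module_name) is not None


def require_optional(module_name: str, group: str) -> None:
    """Raise a descriptive error when an optional dependency is missing.

    Hypothetical helper for illustration only; not part of mlipaudit.
    """
    if not is_installed(module_name):
        raise ImportError(
            f"'{module_name}' is not installed; "
            f"try `uv sync --group {group}` first."
        )


if __name__ == "__main__":
    # Assumption: the jax-md distribution is imported as `jax_md`.
    print("jax_md available:", is_installed("jax_md"))
```

Checking with `find_spec` avoids paying the import cost just to learn whether the optional group was synced.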

src/mlipaudit/app.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from mlipaudit.conformer_selection import ConformerSelectionBenchmark
2727
from mlipaudit.dihedral_scan import DihedralScanBenchmark
2828
from mlipaudit.folding_stability import FoldingStabilityBenchmark
29-
from mlipaudit.io import load_benchmark_results_from_disk
29+
from mlipaudit.io import load_benchmark_results_from_disk, load_scores_from_disk
3030
from mlipaudit.noncovalent_interactions import NoncovalentInteractionsBenchmark
3131
from mlipaudit.reactivity import ReactivityBenchmark
3232
from mlipaudit.ring_planarity import RingPlanarityBenchmark
@@ -43,6 +43,7 @@
4343
conformer_selection_page,
4444
dihedral_scan_page,
4545
folding_stability_page,
46+
leaderboard_page,
4647
noncovalent_interactions_page,
4748
reactivity_page,
4849
ring_planarity_page,
@@ -54,6 +55,9 @@
5455
tautomers_page,
5556
water_radial_distribution_page,
5657
)
58+
from mlipaudit.ui.utils import (
59+
remove_model_name_extensions_and_capitalize_benchmark_names,
60+
)
5761
from mlipaudit.water_radial_distribution import (
5862
WaterRadialDistributionBenchmark,
5963
)
@@ -105,24 +109,40 @@ def main():
105109
"You must provide the results directory as a command line argument, "
106110
"like this: mlipauditapp /path/to/results"
107111
)
112+
is_public = False
113+
if len(sys.argv) == 3 and sys.argv[2] == "__hf":
114+
is_public = True
115+
else:
116+
if not Path(sys.argv[1]).exists():
117+
raise RuntimeError("The specified results directory does not exist.")
118+
119+
results_dir = sys.argv[1]
108120

109-
if not Path(sys.argv[1]).exists():
110-
raise RuntimeError("The specified results directory does not exist.")
121+
results = load_benchmark_results_from_disk(results_dir, BENCHMARKS)
122+
scores = load_scores_from_disk(scores_dir=results_dir)
111123

112-
data = load_benchmark_results_from_disk(sys.argv[1], BENCHMARKS)
124+
if is_public:
125+
remove_model_name_extensions_and_capitalize_benchmark_names(results)
126+
127+
leaderboard = st.Page(
128+
functools.partial(leaderboard_page, scores=scores, is_public=is_public),
129+
title="Leaderboard",
130+
icon=":material/trophy:",
131+
default=True,
132+
)
113133

114134
conformer_selection = st.Page(
115135
functools.partial(
116136
conformer_selection_page,
117-
data_func=_data_func_from_key("conformer_selection", data),
137+
data_func=_data_func_from_key("conformer_selection", results),
118138
),
119139
title="Conformer selection",
120140
url_path="conformer_selection",
121141
)
122142
dihedral_scan = st.Page(
123143
functools.partial(
124144
dihedral_scan_page,
125-
data_func=_data_func_from_key("dihedral_scan", data),
145+
data_func=_data_func_from_key("dihedral_scan", results),
126146
),
127147
title="Dihedral scan",
128148
url_path="dihedral_scan",
@@ -131,23 +151,23 @@ def main():
131151
tautomers = st.Page(
132152
functools.partial(
133153
tautomers_page,
134-
data_func=_data_func_from_key("tautomers", data),
154+
data_func=_data_func_from_key("tautomers", results),
135155
),
136156
title="Tautomers",
137157
url_path="tautomers",
138158
)
139159
noncovalent_interactions = st.Page(
140160
functools.partial(
141161
noncovalent_interactions_page,
142-
data_func=_data_func_from_key("noncovalent_interactions", data),
162+
data_func=_data_func_from_key("noncovalent_interactions", results),
143163
),
144164
title="Noncovalent Interactions",
145165
url_path="noncovalent_interactions",
146166
)
147167
ring_planarity = st.Page(
148168
functools.partial(
149169
ring_planarity_page,
150-
data_func=_data_func_from_key("ring_planarity", data),
170+
data_func=_data_func_from_key("ring_planarity", results),
151171
),
152172
title="Ring planarity",
153173
url_path="ring_planarity",
@@ -156,7 +176,7 @@ def main():
156176
small_molecule_minimization = st.Page(
157177
functools.partial(
158178
small_molecule_minimization_page,
159-
data_func=_data_func_from_key("small_molecule_minimization", data),
179+
data_func=_data_func_from_key("small_molecule_minimization", results),
160180
),
161181
title="Small molecule minimization",
162182
url_path="small_molecule_minimization",
@@ -165,7 +185,7 @@ def main():
165185
reactivity = st.Page(
166186
functools.partial(
167187
reactivity_page,
168-
data_func=_data_func_from_key("reactivity", data),
188+
data_func=_data_func_from_key("reactivity", results),
169189
),
170190
title="Reactivity",
171191
url_path="reactivity",
@@ -174,7 +194,7 @@ def main():
174194
folding_stability = st.Page(
175195
functools.partial(
176196
folding_stability_page,
177-
data_func=_data_func_from_key("folding_stability", data),
197+
data_func=_data_func_from_key("folding_stability", results),
178198
),
179199
title="Protein folding stability",
180200
url_path="protein_folding_stability",
@@ -183,7 +203,7 @@ def main():
183203
bond_length_distribution = st.Page(
184204
functools.partial(
185205
bond_length_distribution_page,
186-
data_func=_data_func_from_key("bond_length_distribution", data),
206+
data_func=_data_func_from_key("bond_length_distribution", results),
187207
),
188208
title="Bond length distribution",
189209
url_path="bond_length_distribution",
@@ -192,7 +212,7 @@ def main():
192212
sampling = st.Page(
193213
functools.partial(
194214
sampling_page,
195-
data_func=_data_func_from_key("sampling", data),
215+
data_func=_data_func_from_key("sampling", results),
196216
),
197217
title="Protein sampling",
198218
url_path="sampling",
@@ -201,7 +221,7 @@ def main():
201221
water_radial_distribution = st.Page(
202222
functools.partial(
203223
water_radial_distribution_page,
204-
data_func=_data_func_from_key("water_radial_distribution", data),
224+
data_func=_data_func_from_key("water_radial_distribution", results),
205225
),
206226
title="Water radial distribution function",
207227
url_path="water_radial_distribution_function",
@@ -210,7 +230,7 @@ def main():
210230
solvent_radial_distribution = st.Page(
211231
functools.partial(
212232
solvent_radial_distribution_page,
213-
data_func=_data_func_from_key("solvent_radial_distribution", data),
233+
data_func=_data_func_from_key("solvent_radial_distribution", results),
214234
),
215235
title="Solvent radial distribution",
216236
url_path="solvent_radial_distribution",
@@ -219,7 +239,7 @@ def main():
219239
stability = st.Page(
220240
functools.partial(
221241
stability_page,
222-
data_func=_data_func_from_key("stability", data),
242+
data_func=_data_func_from_key("stability", results),
223243
),
224244
title="Stability",
225245
url_path="stability",
@@ -228,7 +248,7 @@ def main():
228248
scaling = st.Page(
229249
functools.partial(
230250
scaling_page,
231-
data_func=_data_func_from_key("scaling", data),
251+
data_func=_data_func_from_key("scaling", results),
232252
),
233253
title="Scaling",
234254
url_path="scaling",
@@ -266,14 +286,14 @@ def main():
266286

267287
# Filter pages based on selection
268288
if selected_category == "All Categories":
269-
pages_to_show = (
289+
pages_to_show = [leaderboard] + (
270290
page_categories["Small Molecules"]
271291
+ page_categories["Biomolecules"]
272292
+ page_categories["General"]
273293
)
274294

275295
else:
276-
pages_to_show = page_categories[selected_category]
296+
pages_to_show = [leaderboard] + page_categories[selected_category]
277297

278298
# Set up navigation in main area
279299
pg = st.navigation(pages_to_show)
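
The new argument handling in `main()` can be hard to follow in diff form, so here is a small standalone sketch of the same branching. It is not the code from the commit: `parse_app_args` is a hypothetical name, the logic lives inline in `main()`, and the guard for a missing argument (`len(argv) < 2`) is an assumption because the diff only shows its error message.

```python
from pathlib import Path


def parse_app_args(argv: list[str]) -> tuple[str, bool]:
    """Return (results_dir, is_public) following the branching in the diff.

    Hypothetical helper for illustration; in the commit this is inline code.
    """
    # Assumption: the original guard checks for a missing positional argument.
    if len(argv) < 2:
        raise RuntimeError("You must provide the results directory.")

    # A third argument equal to "__hf" switches the app to public mode.
    is_public = len(argv) == 3 and argv[2] == "__hf"

    # As in the diff, the existence check only runs outside public mode.
    if not is_public and not Path(argv[1]).exists():
        raise RuntimeError("The specified results directory does not exist.")

    # The results directory is always the first positional argument.
    return argv[1], is_public
```

Passing `sys.argv` to such a function would keep the branching testable without starting the Streamlit app.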

src/mlipaudit/benchmark.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ase import Atom
2222
from huggingface_hub import hf_hub_download
2323
from mlip.models import ForceField
24-
from pydantic import BaseModel
24+
from pydantic import BaseModel, Field
2525

2626
from mlipaudit.exceptions import ChemicalElementsMissingError
2727
from mlipaudit.run_mode import RunMode
@@ -30,7 +30,14 @@
3030

3131

3232
class BenchmarkResult(BaseModel):
33-
"""A base model for all benchmark results."""
33+
"""A base model for all benchmark results.
34+
35+
Attributes:
36+
score: The final score for the benchmark between
37+
0 and 1.
38+
"""
39+
40+
score: float | None = Field(ge=0, le=1, default=None)
3441

3542

3643
class ModelOutput(BaseModel):
@@ -212,7 +219,8 @@ def analyze(self) -> BenchmarkResult:
212219
213220
Subclasses must implement this method. This method
214221
processes the raw data generated from the generation step
215-
to compute final metrics.
222+
to compute final metrics. Subclasses are also responsible
223+
for computing the final score for the benchmark.
216224
217225
Returns:
218226
A class-specific instance of `BenchmarkResult`.
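
The contract introduced here is that `score` is either unset or a value in the closed interval from 0 to 1, and that each benchmark's `analyze` step fills it in. The sketch below illustrates that contract with a standard-library dataclass as a stand-in for the pydantic model, so it runs without third-party packages. `ResultSketch`, `analyze_sketch`, and the pass-fraction scoring rule are all invented for illustration; the actual scoring functions in `mlipaudit.scoring` are not shown in this diff.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ResultSketch:
    """Stdlib stand-in for `BenchmarkResult`; not the pydantic model itself."""

    score: Optional[float] = None

    def __post_init__(self) -> None:
        # Same bounds as Field(ge=0, le=1): None is allowed,
        # otherwise the score must satisfy 0 <= score <= 1.
        if self.score is not None and not 0 <= self.score <= 1:
            raise ValueError(f"score must be in [0, 1], got {self.score}")


def analyze_sketch(errors: list[float], tolerance: float) -> ResultSketch:
    """Hypothetical analysis step: fraction of errors within a tolerance."""
    if not errors:
        # No data to evaluate, so the score stays unset.
        return ResultSketch()
    passed = sum(1 for e in errors if abs(e) <= tolerance)
    return ResultSketch(score=passed / len(errors))
```

Keeping the bounds check on the result type means a benchmark that computes an out-of-range score fails when the result is constructed instead of when the leaderboard later reads it.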
