Skip to content

Commit fcc5718

Browse files
authored
feat: store model selection when switching pages (#96)
* feat: add `fetch_selected_models`
  This function is not yet stable and requires some iteration.
* feat: remember selected models between pages
* feat: remove spurious title in sidebar + some missing data fixes
* feat: make the default all models
* fix: failing tests
* feat: final fix to model selection
* test: update model selection import
* feat: ensure model names in correct order
* fix: identical model names
1 parent 1d9aefe commit fcc5718

18 files changed

+243
-208
lines changed

src/mlipaudit/app.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from mlipaudit.io import load_benchmark_results_from_disk, load_scores_from_disk
2626
from mlipaudit.ui import leaderboard_page
2727
from mlipaudit.ui.page_wrapper import UIPageWrapper
28+
from mlipaudit.ui.utils import model_selection
2829

2930

3031
def _data_func_from_key(
@@ -156,6 +157,10 @@ def main() -> None:
156157
else:
157158
pages_to_show = [leaderboard] + page_categories[selected_category]
158159

160+
# Add model selection
161+
with st.sidebar:
162+
model_selection(unique_model_names=list(results.keys()))
163+
159164
# Set up navigation in main area
160165
pg = st.navigation(pages_to_show)
161166

src/mlipaudit/ui/bond_length_distribution.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
BondLengthDistributionResult,
2323
)
2424
from mlipaudit.ui.page_wrapper import UIPageWrapper
25-
from mlipaudit.ui.utils import display_model_scores
25+
from mlipaudit.ui.utils import display_model_scores, fetch_selected_models
2626

2727
ModelName: TypeAlias = str
2828
BenchmarkResultForMultipleModels: TypeAlias = dict[
@@ -53,7 +53,6 @@ def bond_length_distribution_page(
5353
keys and the benchmark results objects as values.
5454
"""
5555
st.markdown("# Bond length distribution")
56-
st.sidebar.markdown("# Bond length distribution")
5756

5857
st.markdown(
5958
"The benchmark runs short simulations of small molecules to check whether the "
@@ -84,11 +83,11 @@ def bond_length_distribution_page(
8483
st.markdown("**No results to display**.")
8584
return
8685

87-
unique_model_names = list(set(data.keys()))
88-
model_select = st.sidebar.multiselect(
89-
"Select model(s)", unique_model_names, default=unique_model_names
90-
)
91-
selected_models = model_select if model_select else unique_model_names
86+
selected_models = fetch_selected_models(available_models=list(data.keys()))
87+
88+
if not selected_models:
89+
st.markdown("**No results to display**.")
90+
return
9291

9392
distribution_data = [
9493
{

src/mlipaudit/ui/conformer_selection.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222

2323
from mlipaudit.benchmarks import ConformerSelectionBenchmark, ConformerSelectionResult
2424
from mlipaudit.ui.page_wrapper import UIPageWrapper
25-
from mlipaudit.ui.utils import create_st_image, display_model_scores
25+
from mlipaudit.ui.utils import (
26+
create_st_image,
27+
display_model_scores,
28+
fetch_selected_models,
29+
)
2630

2731
APP_DATA_DIR = Path(__file__).parent.parent / "app_data"
2832
CONFORMER_IMG_DIR = APP_DATA_DIR / "conformer_selection" / "img"
@@ -35,6 +39,7 @@ def _process_data_into_dataframe(
3539
selected_models: list[str],
3640
) -> pd.DataFrame:
3741
converted_data_scores = []
42+
model_names = []
3843
for model_name, results in data.items():
3944
if model_name in selected_models:
4045
model_data_converted = {
@@ -46,8 +51,9 @@ def _process_data_into_dataframe(
4651
),
4752
}
4853
converted_data_scores.append(model_data_converted)
54+
model_names.append(model_name)
4955

50-
return pd.DataFrame(converted_data_scores, index=selected_models)
56+
return pd.DataFrame(converted_data_scores, index=model_names)
5157

5258

5359
def _molecule_stats_df(results: ConformerSelectionResult) -> pd.DataFrame:
@@ -76,7 +82,6 @@ def conformer_selection_page(
7682
keys and the benchmark results objects as values.
7783
"""
7884
st.markdown("# Conformer selection")
79-
st.sidebar.markdown("# Conformer selection")
8085

8186
st.markdown(
8287
"Organic molecules are flexible and able to adopt multiple conformations. "
@@ -125,17 +130,11 @@ def conformer_selection_page(
125130
st.markdown("**No results to display**.")
126131
return
127132

128-
model_names = list(data.keys())
129-
model_select = st.sidebar.multiselect(
130-
"Select model(s)", model_names, default=model_names
131-
)
132-
# with st.sidebar.container():
133-
# selected_energy_unit = st.selectbox(
134-
# "Select an energy unit:",
135-
# ["kcal/mol", "eV"],
136-
# )
133+
selected_models = fetch_selected_models(available_models=list(data.keys()))
137134

138-
selected_models = model_select if model_select else model_names
135+
if not selected_models:
136+
st.markdown("**No results to display**.")
137+
return
139138

140139
df = _process_data_into_dataframe(data, selected_models)
141140
df_display = df.copy()

src/mlipaudit/ui/dihedral_scan.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828
DihedralScanResult,
2929
)
3030
from mlipaudit.ui.page_wrapper import UIPageWrapper
31-
from mlipaudit.ui.utils import create_st_image, display_model_scores
31+
from mlipaudit.ui.utils import (
32+
create_st_image,
33+
display_model_scores,
34+
fetch_selected_models,
35+
)
3236

3337
APP_DATA_DIR = Path(__file__).parent.parent / "app_data"
3438
DIHEDRAL_SCAN_DATA_DIR = APP_DATA_DIR / "dihedral_scan"
@@ -88,7 +92,6 @@ def dihedral_scan_page(
8892
keys and the benchmark results objects as values.
8993
"""
9094
st.markdown("# Dihedral scan")
91-
st.sidebar.markdown("# Dihedral scan")
9295

9396
with st.sidebar.container():
9497
selected_energy_unit = st.selectbox(
@@ -134,11 +137,11 @@ def dihedral_scan_page(
134137
st.markdown("**No results to display**.")
135138
return
136139

137-
unique_model_names = list(set(data.keys()))
138-
model_select = st.sidebar.multiselect(
139-
"Select model(s)", unique_model_names, default=unique_model_names
140-
)
141-
selected_models = model_select if model_select else unique_model_names
140+
selected_models = fetch_selected_models(available_models=list(data.keys()))
141+
142+
if not selected_models:
143+
st.markdown("**No results to display**.")
144+
return
142145

143146
conversion_factor = (
144147
1.0 if selected_energy_unit == "kcal/mol" else (units.kcal / units.mol)
@@ -153,6 +156,7 @@ def dihedral_scan_page(
153156
"Pearson Correlation": result.avg_pearson_r,
154157
}
155158
for model_name, result in data.items()
159+
if model_name in selected_models
156160
]
157161

158162
# Create summary dataframe

src/mlipaudit/ui/folding_stability.py

Lines changed: 38 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -20,46 +20,51 @@
2020

2121
from mlipaudit.benchmarks import FoldingStabilityBenchmark, FoldingStabilityResult
2222
from mlipaudit.ui.page_wrapper import UIPageWrapper
23-
from mlipaudit.ui.utils import display_model_scores
23+
from mlipaudit.ui.utils import display_model_scores, fetch_selected_models
2424

2525
ModelName: TypeAlias = str
2626
BenchmarkResultForMultipleModels: TypeAlias = dict[ModelName, FoldingStabilityResult]
2727

2828

2929
def _data_to_dataframes(
3030
data: BenchmarkResultForMultipleModels,
31+
selected_models: list[str],
3132
) -> tuple[pd.DataFrame, pd.DataFrame]:
3233
plot_data = []
3334
agg_data = []
3435

3536
for model_name, result in data.items():
36-
for molecule_result in result.molecules:
37-
if not molecule_result.failed:
38-
for idx in range(len(molecule_result.rmsd_trajectory)): # type: ignore
39-
plot_data.append({
40-
"Model": model_name,
41-
"Structure": molecule_result.structure_name,
42-
"Frame": idx,
43-
"RMSD": molecule_result.rmsd_trajectory[idx], # type: ignore
44-
"TM score": molecule_result.tm_score_trajectory[idx], # type: ignore
45-
"Rad of Gyr Dev": molecule_result.radius_of_gyration_deviation[ # type: ignore
46-
idx
47-
],
48-
"DSSP match": molecule_result.match_secondary_structure[idx], # type: ignore
49-
})
50-
# Next line is to stay within max. line length below
51-
max_dev_rad_of_gyr = (
52-
molecule_result.max_abs_deviation_radius_of_gyration
53-
)
54-
agg_data.append({
55-
"Model": model_name,
56-
"Score": result.score,
57-
"Structure": molecule_result.structure_name,
58-
"avg. RMSD": molecule_result.avg_rmsd,
59-
"avg. TM score": molecule_result.avg_tm_score,
60-
"avg. DSSP match": molecule_result.avg_match,
61-
"max. abs. dev. Rad. of Gyr.": max_dev_rad_of_gyr,
62-
})
37+
if model_name in selected_models:
38+
for molecule_result in result.molecules:
39+
if not molecule_result.failed:
40+
for idx in range(len(molecule_result.rmsd_trajectory)): # type: ignore
41+
rad_of_gyr_dev = (
42+
molecule_result.radius_of_gyration_deviation[idx] # type: ignore
43+
)
44+
plot_data.append({
45+
"Model": model_name,
46+
"Structure": molecule_result.structure_name,
47+
"Frame": idx,
48+
"RMSD": molecule_result.rmsd_trajectory[idx], # type: ignore
49+
"TM score": molecule_result.tm_score_trajectory[idx], # type: ignore
50+
"Rad of Gyr Dev": rad_of_gyr_dev,
51+
"DSSP match": molecule_result.match_secondary_structure[
52+
idx
53+
], # type: ignore
54+
})
55+
# Next line is to stay within max. line length below
56+
max_dev_rad_of_gyr = (
57+
molecule_result.max_abs_deviation_radius_of_gyration
58+
)
59+
agg_data.append({
60+
"Model": model_name,
61+
"Score": result.score,
62+
"Structure": molecule_result.structure_name,
63+
"avg. RMSD": molecule_result.avg_rmsd,
64+
"avg. TM score": molecule_result.avg_tm_score,
65+
"avg. DSSP match": molecule_result.avg_match,
66+
"max. abs. dev. Rad. of Gyr.": max_dev_rad_of_gyr,
67+
})
6368

6469
df = pd.DataFrame(plot_data)
6570
df_agg = pd.DataFrame(agg_data)
@@ -181,7 +186,6 @@ def folding_stability_page(
181186
keys and the benchmark results objects as values.
182187
"""
183188
st.markdown("# Folding stability of trajectories")
184-
st.sidebar.markdown("# Folding stability")
185189

186190
st.markdown(
187191
"This module examines the folding stability trajectories of proteins in MLIP "
@@ -209,21 +213,13 @@ def folding_stability_page(
209213
st.markdown("**No results to display**.")
210214
return
211215

212-
df, df_agg = _data_to_dataframes(data)
213-
214-
unique_model_ids = list(data.keys())
215-
216-
# Add "Select All" option
217-
all_models_option = st.sidebar.checkbox("Select all models", value=False)
216+
selected_models = fetch_selected_models(available_models=list(data.keys()))
218217

219-
if all_models_option:
220-
model_select = unique_model_ids
221-
else:
222-
model_select = st.sidebar.multiselect(
223-
"Select model(s)", unique_model_ids, default=unique_model_ids
224-
)
218+
if not selected_models:
219+
st.markdown("**No results to display**.")
220+
return
225221

226-
selected_models = model_select if model_select else unique_model_ids
222+
df, df_agg = _data_to_dataframes(data, selected_models)
227223

228224
unique_structures = list(set(df["Structure"].unique()))
229225

0 commit comments

Comments
 (0)