Skip to content

Commit 6223dce

Browse files
authored
feat: small UI changes (#98)
* feat: remove bold MLIP in reactivity.py
* feat: rename Inferred->Predicted
* feat: add units to BLD
* feat: remove best model in BLD
* fix: score not being dropped
* feat: add units for DS
* docs: annotate convergence_rate in NEB
* feat: improve NEB table
* feat: ensure CS model correct ordering
* docs: better docstring for `display_model_scores()`
* feat: add units to reactivity
* feat: add units to ring planarity
* feat: add units to SMM and fix tooltip
* feat: add units to tautomers
* feat: add units to folding stability
* feat: add units to sampling
* feat: add units to Solvent RDF
* feat: add units to water RDF
* feat: improvements to stability
* feat: more improvements to stability
* feat: slant models on x axis + consistent naming
* feat: add space between units
1 parent fcc5718 commit 6223dce

18 files changed

+193
-174
lines changed

src/mlipaudit/benchmarks/nudged_elastic_band/nudged_elastic_band.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ class NEBResult(BenchmarkResult):
160160
Attributes:
161161
reaction_results: A dictionary of reaction results where
162162
the keys are the reaction identifiers.
163+
convergence_rate: The fraction of converged reactions.
163164
"""
164165

165166
reaction_results: list[NEBReactionResult]

src/mlipaudit/benchmarks/sampling/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,13 @@ def calculate_distribution_hellinger_distance(
117117
reference_hist: Reference histogram array
118118
sampled_hist: Sampled histogram array
119119
normalize: Whether to normalize histograms before comparison (only
120-
set this to False if the histograms are already normalized).
120+
set this to False if the histograms are already normalized).
121121
122122
Raises:
123123
ValueError: If the histograms have different shapes.
124124
125125
Returns:
126-
float: Hellinger distance between the two distributions
126+
The Hellinger distance between the two distributions
127127
"""
128128
if reference_hist.shape != sampled_hist.shape:
129129
raise ValueError("Histograms must have the same shape")

src/mlipaudit/ui/bond_length_distribution.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -92,31 +92,20 @@ def bond_length_distribution_page(
9292
distribution_data = [
9393
{
9494
"Model name": model_name,
95-
"Average deviation": result.avg_deviation,
95+
"Average deviation (Å)": result.avg_deviation,
9696
"Score": result.score,
9797
}
9898
for model_name, result in data.items()
9999
if model_name in selected_models
100100
]
101101

102+
st.markdown("## Summary statistics")
103+
102104
df = pd.DataFrame(distribution_data)
103105

104106
df.sort_values("Score", ascending=False, inplace=True)
105107
display_model_scores(df)
106108

107-
st.markdown("## Best model summary")
108-
109-
# Get best model
110-
best_model_row = df.loc[df["Score"].idxmax()]
111-
best_model_name = best_model_row["Model name"]
112-
113-
st.markdown(f"The best model is **{best_model_name}**.")
114-
115-
st.metric(
116-
"Total average deviation (absolute)",
117-
f"{float(best_model_row['Average deviation']):.3f}",
118-
)
119-
120109
st.markdown("## Bond length deviation distribution per model")
121110

122111
# Get all unique ring types from the data
@@ -154,7 +143,7 @@ def bond_length_distribution_page(
154143
.encode(
155144
x=alt.X(
156145
"Model name:N",
157-
title="Model name",
146+
title="Model",
158147
axis=alt.Axis(labelAngle=-45, labelLimit=100),
159148
),
160149
y=alt.Y(
@@ -164,7 +153,7 @@ def bond_length_distribution_page(
164153
),
165154
color=alt.Color(
166155
"Model name:N",
167-
title="Model name",
156+
title="Model",
168157
legend=alt.Legend(orient="top"),
169158
),
170159
)

src/mlipaudit/ui/conformer_selection.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,6 @@ def conformer_selection_page(
115115
with col3:
116116
create_st_image(CONFORMER_IMG_DIR / "rsz_efa00.png", "Efavirenz")
117117

118-
st.markdown("## Summary statistics")
119-
120118
# Download data and get model names
121119
if "conformer_selection_cached_data" not in st.session_state:
122120
st.session_state.conformer_selection_cached_data = data_func()
@@ -137,13 +135,15 @@ def conformer_selection_page(
137135
return
138136

139137
df = _process_data_into_dataframe(data, selected_models)
138+
139+
st.markdown("## Summary statistics")
140+
140141
df_display = df.copy()
141142
df_display.index.name = "Model name"
142143
df_display.sort_values("Score", ascending=False, inplace=True)
143144
display_model_scores(df_display)
144145

145146
st.markdown("## MAE and RMSE per model")
146-
st.markdown("")
147147

148148
# Melt the dataframe to prepare for Altair chart
149149
chart_df = (
@@ -202,7 +202,11 @@ def conformer_selection_page(
202202
alt.Chart(error_chart_df)
203203
.mark_bar()
204204
.encode(
205-
x=alt.X("Molecule:N", title="Molecule"),
205+
x=alt.X(
206+
"Molecule:N",
207+
title="Molecule",
208+
axis=alt.Axis(labelAngle=-45, labelLimit=100),
209+
),
206210
y=alt.Y("Value:Q", title="Error (kcal/mol)"),
207211
color="Metric:N",
208212
xOffset="Metric:N",
@@ -258,7 +262,7 @@ def conformer_selection_page(
258262
.mark_circle(size=80, opacity=0.7)
259263
.encode(
260264
x=alt.X("Reference Energy:Q", title="Reference Energy (kcal/mol)"),
261-
y=alt.Y("Predicted Energy:Q", title="Inferred Energy (kcal/mol)"),
265+
y=alt.Y("Predicted Energy:Q", title="Predicted Energy (kcal/mol)"),
262266
tooltip=["Reference Energy:Q", "Energy:Q"],
263267
)
264268
.properties(

src/mlipaudit/ui/dihedral_scan.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,10 @@ def dihedral_scan_page(
150150
{
151151
"Model name": model_name,
152152
"Score": result.score,
153-
"MAE": result.avg_mae * conversion_factor,
154-
"RMSE": result.avg_rmse * conversion_factor,
155-
"Barrier Height Error": result.mae_barrier_height * conversion_factor,
153+
f"MAE ({selected_energy_unit})": result.avg_mae * conversion_factor,
154+
f"RMSE ({selected_energy_unit})": result.avg_rmse * conversion_factor,
155+
f"Barrier Height MAE ({selected_energy_unit})": result.mae_barrier_height
156+
* conversion_factor,
156157
"Pearson Correlation": result.avg_pearson_r,
157158
}
158159
for model_name, result in data.items()
@@ -174,20 +175,27 @@ def dihedral_scan_page(
174175

175176
st.markdown("## Mean barrier height error")
176177
df_barrier = df[df["Model name"].isin(selected_models)][
177-
["Model name", "Barrier Height Error"]
178+
["Model name", f"Barrier Height MAE ({selected_energy_unit})"]
178179
]
179180

180181
barrier_chart = (
181182
alt.Chart(df_barrier)
182183
.mark_bar()
183184
.encode(
184-
x=alt.X("Model name:N", title="Model ID"),
185+
x=alt.X(
186+
"Model name:N",
187+
title="Model",
188+
axis=alt.Axis(labelAngle=-45, labelLimit=100),
189+
),
185190
y=alt.Y(
186-
"Barrier Height Error:Q",
191+
f"Barrier Height MAE ({selected_energy_unit}):Q",
187192
title=f"Mean Barrier Height Error ({selected_energy_unit})",
188193
),
189-
color=alt.Color("Model name:N", title="Model ID"),
190-
tooltip=["Model name:N", "Barrier Height Error:Q"],
194+
color=alt.Color("Model name:N", title="Model"),
195+
tooltip=[
196+
alt.Tooltip("Model name:N", title="Model"),
197+
f"Barrier Height MAE ({selected_energy_unit}):Q",
198+
],
191199
)
192200
.properties(
193201
width=600,

src/mlipaudit/ui/folding_stability.py

Lines changed: 64 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -34,37 +34,39 @@ def _data_to_dataframes(
3434
agg_data = []
3535

3636
for model_name, result in data.items():
37-
if model_name in selected_models:
38-
for molecule_result in result.molecules:
39-
if not molecule_result.failed:
40-
for idx in range(len(molecule_result.rmsd_trajectory)): # type: ignore
41-
rad_of_gyr_dev = (
42-
molecule_result.radius_of_gyration_deviation[idx] # type: ignore
43-
)
44-
plot_data.append({
45-
"Model": model_name,
46-
"Structure": molecule_result.structure_name,
47-
"Frame": idx,
48-
"RMSD": molecule_result.rmsd_trajectory[idx], # type: ignore
49-
"TM score": molecule_result.tm_score_trajectory[idx], # type: ignore
50-
"Rad of Gyr Dev": rad_of_gyr_dev,
51-
"DSSP match": molecule_result.match_secondary_structure[
52-
idx
53-
], # type: ignore
54-
})
55-
# Next line is to stay within max. line length below
56-
max_dev_rad_of_gyr = (
57-
molecule_result.max_abs_deviation_radius_of_gyration
58-
)
59-
agg_data.append({
60-
"Model": model_name,
61-
"Score": result.score,
62-
"Structure": molecule_result.structure_name,
63-
"avg. RMSD": molecule_result.avg_rmsd,
64-
"avg. TM score": molecule_result.avg_tm_score,
65-
"avg. DSSP match": molecule_result.avg_match,
66-
"max. abs. dev. Rad. of Gyr.": max_dev_rad_of_gyr,
67-
})
37+
if model_name not in selected_models:
38+
continue
39+
40+
for molecule_result in result.molecules:
41+
if molecule_result.failed:
42+
continue
43+
44+
for idx in range(len(molecule_result.rmsd_trajectory)): # type: ignore
45+
plot_data.append({
46+
"Model": model_name,
47+
"Structure": molecule_result.structure_name,
48+
"Frame": idx,
49+
"RMSD": molecule_result.rmsd_trajectory[idx], # type: ignore
50+
"TM score": molecule_result.tm_score_trajectory[idx], # type: ignore
51+
"Rad of Gyr Dev": molecule_result.radius_of_gyration_deviation[ # type: ignore
52+
idx
53+
],
54+
"DSSP match": molecule_result.match_secondary_structure[idx], # type: ignore
55+
})
56+
# Next line is to stay within max. line length below
57+
max_dev_rad_of_gyr = (
58+
molecule_result.max_abs_deviation_radius_of_gyration
59+
)
60+
agg_data.append({
61+
"Model": model_name,
62+
"Score": result.score,
63+
"Structure": molecule_result.structure_name,
64+
"Average RMSD (Å)": molecule_result.avg_rmsd,
65+
"Average TM score": molecule_result.avg_tm_score,
66+
"Average DSSP match": molecule_result.avg_match,
67+
"Maximum absolute deviation"
68+
" of the radius of gyration (Å)": max_dev_rad_of_gyr,
69+
})
6870

6971
df = pd.DataFrame(plot_data)
7072
df_agg = pd.DataFrame(agg_data)
@@ -91,10 +93,10 @@ def _transform_dataframes_for_visualization(
9193
df_agg_filtered.groupby("Model")
9294
.agg({
9395
"Score": "mean",
94-
"avg. RMSD": "mean",
95-
"avg. TM score": "mean",
96-
"avg. DSSP match": "mean",
97-
"max. abs. dev. Rad. of Gyr.": "mean",
96+
"Average RMSD (Å)": "mean",
97+
"Average TM score": "mean",
98+
"Average DSSP match": "mean",
99+
"Maximum absolute deviation of the radius of gyration (Å)": "mean",
98100
})
99101
.round(4)
100102
.reset_index()
@@ -114,27 +116,32 @@ def _transform_dataframes_for_visualization(
114116

115117
# Ensure numeric values for aggregation
116118
df_agg_filtered_numeric = df_agg_filtered.copy()
117-
df_agg_filtered_numeric["avg. RMSD"] = pd.to_numeric(
118-
df_agg_filtered_numeric["avg. RMSD"], errors="coerce"
119+
df_agg_filtered_numeric["Average RMSD (Å)"] = pd.to_numeric(
120+
df_agg_filtered_numeric["Average RMSD (Å)"], errors="coerce"
119121
)
120-
df_agg_filtered_numeric["avg. TM score"] = pd.to_numeric(
121-
df_agg_filtered_numeric["avg. TM score"], errors="coerce"
122+
df_agg_filtered_numeric["Average TM score"] = pd.to_numeric(
123+
df_agg_filtered_numeric["Average TM score"], errors="coerce"
122124
)
123-
df_agg_filtered_numeric["avg. DSSP match"] = pd.to_numeric(
124-
df_agg_filtered_numeric["avg. DSSP match"], errors="coerce"
125+
df_agg_filtered_numeric["Average DSSP match"] = pd.to_numeric(
126+
df_agg_filtered_numeric["Average DSSP match"], errors="coerce"
125127
)
126-
df_agg_filtered_numeric["max. abs. dev. Rad. of Gyr."] = pd.to_numeric(
127-
df_agg_filtered_numeric["max. abs. dev. Rad. of Gyr."], errors="coerce"
128+
df_agg_filtered_numeric[
129+
"Maximum absolute deviation of the radius of gyration (Å)"
130+
] = pd.to_numeric(
131+
df_agg_filtered_numeric[
132+
"Maximum absolute deviation of the radius of gyration (Å)"
133+
],
134+
errors="coerce",
128135
)
129136

130137
# Calculate averages across structures for each model
131138
avg_metrics = (
132139
df_agg_filtered_numeric.groupby("Model")
133140
.agg({
134-
"avg. RMSD": "mean",
135-
"avg. TM score": "mean",
136-
"avg. DSSP match": "mean",
137-
"max. abs. dev. Rad. of Gyr.": "mean",
141+
"Average RMSD (Å)": "mean",
142+
"Average TM score": "mean",
143+
"Average DSSP match": "mean",
144+
"Maximum absolute deviation of the radius of gyration (Å)": "mean",
138145
})
139146
.reset_index()
140147
)
@@ -146,10 +153,10 @@ def _transform_dataframes_for_visualization(
146153
metrics_long = avg_metrics.melt(
147154
id_vars=["Model"],
148155
value_vars=[
149-
"avg. RMSD",
150-
"avg. TM score",
151-
"avg. DSSP match",
152-
"max. abs. dev. Rad. of Gyr.",
156+
"Average RMSD (Å)",
157+
"Average TM score",
158+
"Average DSSP match",
159+
"Maximum absolute deviation of the radius of gyration (Å)",
153160
],
154161
var_name="Metric",
155162
value_name="Value",
@@ -236,10 +243,13 @@ def folding_stability_page(
236243
# the lower/closer to 0 the value the better
237244
# and one for the closer to 1 the value the better
238245
metrics_long_0 = metrics_long[
239-
metrics_long["Metric"].isin(["avg. RMSD", "max. abs. dev. Rad. of Gyr."])
246+
metrics_long["Metric"].isin([
247+
"Average RMSD (Å)",
248+
"Maximum absolute deviation of the radius of gyration (Å)",
249+
])
240250
].copy()
241251
metrics_long_1 = metrics_long[
242-
metrics_long["Metric"].isin(["avg. TM score", "avg. DSSP match"])
252+
metrics_long["Metric"].isin(["Average TM score", "Average DSSP match"])
243253
].copy()
244254
st.markdown("### RMSD and Radius of Gyration")
245255
# Create a grouped bar chart
@@ -253,7 +263,7 @@ def folding_stability_page(
253263
sort=None,
254264
axis=alt.Axis(labelAngle=-45, labelLimit=100),
255265
),
256-
y=alt.Y("Value:Q", title="Value"),
266+
y=alt.Y("Value:Q", title="Metric"),
257267
color=alt.Color("Metric:N", title="Metric"),
258268
xOffset=alt.XOffset("Metric:N"),
259269
tooltip=["Model:N", "Metric:N", "Value:Q"],

src/mlipaudit/ui/noncovalent_interactions.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def noncovalent_interactions_page(
227227
)
228228

229229
# Drop the score for the rest of processing
230-
df_subset.drop(columns=["Score"])
230+
df_subset = df_subset.drop(columns=["Score"])
231231

232232
# Reshape dataframe for Altair plotting
233233
df_melted = (
@@ -251,9 +251,13 @@ def noncovalent_interactions_page(
251251
),
252252
x=alt.X("RMSE:Q", title="RMSE (kcal/mol)"),
253253
yOffset=alt.YOffset("Model name:N"),
254-
color=alt.Color("Model name:N", title="Model Name"),
254+
color=alt.Color("Model name:N", title="Model"),
255255
opacity=alt.condition(selection, alt.value(0.8), alt.value(0.3)),
256-
tooltip=["Model name:N", "Interaction type:N", "RMSE:Q"],
256+
tooltip=[
257+
alt.Tooltip("Model name:N", title="Model"),
258+
"Interaction type:N",
259+
"RMSE:Q",
260+
],
257261
)
258262
.resolve_scale(color="independent")
259263
.properties(

0 commit comments

Comments
 (0)