Skip to content

Commit 255f838

Browse files
authored
feat: update conformer selection (#77)
* feat: display per model per molecule stats
* feat: display pearson correlation in summary
* feat: remove spurious markdown
* feat: add correlation plot
1 parent f322903 commit 255f838

File tree

1 file changed

+146
-8
lines changed

1 file changed

+146
-8
lines changed

src/mlipaudit/ui/conformer_selection.py

Lines changed: 146 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import statistics
1516
from pathlib import Path
1617
from typing import Callable, TypeAlias
1718

@@ -38,14 +39,32 @@ def _process_data_into_dataframe(
3839
if model_name in selected_models:
3940
model_data_converted = {
4041
"Score": results.score,
41-
"RMSE": results.avg_rmse,
42-
"MAE": results.avg_mae,
42+
"Average RMSE (kcal/mol)": results.avg_rmse,
43+
"Average MAE (kcal/mol)": results.avg_mae,
44+
"Average Spearman correlation": statistics.mean(
45+
r.spearman_correlation for r in results.molecules
46+
),
4347
}
4448
converted_data_scores.append(model_data_converted)
4549

4650
return pd.DataFrame(converted_data_scores, index=selected_models)
4751

4852

53+
def _molecule_stats_df(results: ConformerSelectionResult) -> pd.DataFrame:
54+
"""Return a dataframe with per-molecule stats for a benchmark result."""
55+
rows = []
56+
for m in results.molecules:
57+
rows.append({
58+
"Molecule": m.molecule_name,
59+
"MAE (kcal/mol)": float(m.mae),
60+
"RMSE (kcal/mol)": float(m.rmse),
61+
"Spearman": float(m.spearman_correlation),
62+
"Spearman p": float(m.spearman_p_value),
63+
})
64+
df = pd.DataFrame(rows).set_index("Molecule")
65+
return df
66+
67+
4968
def conformer_selection_page(
5069
data_func: Callable[[], BenchmarkResultForMultipleModels],
5170
) -> None:
@@ -91,9 +110,7 @@ def conformer_selection_page(
91110
with col3:
92111
create_st_image(CONFORMER_IMG_DIR / "rsz_efa00.png", "Efavirenz")
93112

94-
st.markdown("")
95113
st.markdown("## Summary statistics")
96-
st.markdown("")
97114

98115
# Download data and get model names
99116
if "conformer_selection_cached_data" not in st.session_state:
@@ -112,6 +129,12 @@ def conformer_selection_page(
112129
model_select = st.sidebar.multiselect(
113130
"Select model(s)", model_names, default=model_names
114131
)
132+
# with st.sidebar.container():
133+
# selected_energy_unit = st.selectbox(
134+
# "Select an energy unit:",
135+
# ["kcal/mol", "eV"],
136+
# )
137+
115138
selected_models = model_select if model_select else model_names
116139

117140
df = _process_data_into_dataframe(data, selected_models)
@@ -128,16 +151,13 @@ def conformer_selection_page(
128151
df.reset_index()
129152
.melt(
130153
id_vars=["index"],
131-
value_vars=["RMSE", "MAE"],
154+
value_vars=["Average RMSE (kcal/mol)", "Average MAE (kcal/mol)"],
132155
var_name="Metric",
133156
value_name="Value",
134157
)
135158
.rename(columns={"index": "Model"})
136159
)
137160

138-
# Capitalize metric names for better display
139-
chart_df["Metric"] = chart_df["Metric"].str.upper()
140-
141161
# Create grouped bar chart
142162
chart = (
143163
alt.Chart(chart_df)
@@ -155,6 +175,124 @@ def conformer_selection_page(
155175

156176
st.altair_chart(chart, use_container_width=True)
157177

178+
# Per-molecule statistics: one table and one error chart per selected model
179+
st.markdown("## Per-molecule statistics")
180+
st.markdown(
181+
"Per-molecule MAE, RMSE and Spearman correlation for each selected model."
182+
)
183+
184+
for model_name in selected_models:
185+
results = data.get(model_name)
186+
if results is None:
187+
continue
188+
189+
st.markdown(f"### {model_name}")
190+
mol_df = _molecule_stats_df(results)
191+
192+
# Display table
193+
st.dataframe(mol_df.round(4))
194+
195+
# Error chart (MAE and RMSE)
196+
error_chart_df = mol_df.reset_index().melt(
197+
id_vars=["Molecule"],
198+
value_vars=["MAE (kcal/mol)", "RMSE (kcal/mol)"],
199+
var_name="Metric",
200+
value_name="Value",
201+
)
202+
error_chart = (
203+
alt.Chart(error_chart_df)
204+
.mark_bar()
205+
.encode(
206+
x=alt.X("Molecule:N", title="Molecule"),
207+
y=alt.Y("Value:Q", title="Error (kcal/mol)"),
208+
color="Metric:N",
209+
xOffset="Metric:N",
210+
)
211+
.properties(width=600, height=250)
212+
)
213+
st.altair_chart(error_chart, use_container_width=True)
214+
215+
# Plot correlation chart for a chosen molecule and model
216+
217+
# Create selectboxes for model and structure selection
218+
col1, col2 = st.columns(2)
219+
with col1:
220+
selected_plot_model = st.selectbox(
221+
"Select model for plot:", selected_models, key="model_selector_plot"
222+
)
223+
224+
unique_structures = list(
225+
set([mol.molecule_name for mol in data[selected_plot_model].molecules])
226+
)
227+
228+
with col2:
229+
selected_structure = st.selectbox(
230+
"Select structure for plot:",
231+
unique_structures,
232+
key="structure_selector_plot",
233+
)
234+
235+
model_data_for_plot = [
236+
mol
237+
for mol in data[selected_plot_model].molecules
238+
if mol.molecule_name == selected_structure
239+
][0]
240+
scatter_data = []
241+
for pred_energy, ref_energy in zip(
242+
model_data_for_plot.predicted_energy_profile,
243+
model_data_for_plot.reference_energy_profile,
244+
):
245+
scatter_data.append({
246+
"Predicted Energy": pred_energy,
247+
"Reference Energy": ref_energy,
248+
})
249+
250+
structure_df = pd.DataFrame(scatter_data)
251+
252+
spearman_corr = structure_df["Predicted Energy"].corr(
253+
structure_df["Reference Energy"], method="spearman"
254+
)
255+
256+
# Create scatter plot
257+
# Scatter of predicted vs. reference energies for the selected model and
# structure. The tooltip fields must name columns that exist in
# ``structure_df`` ("Reference Energy" and "Predicted Energy"); the previous
# "Energy:Q" entry referred to a column that is never created.
scatter_chart = (
    alt.Chart(structure_df)
    .mark_circle(size=80, opacity=0.7)
    .encode(
        x=alt.X("Reference Energy:Q", title="Reference Energy (kcal/mol)"),
        y=alt.Y("Predicted Energy:Q", title="Inferred Energy (kcal/mol)"),
        tooltip=["Reference Energy:Q", "Predicted Energy:Q"],
    )
    .properties(
        width=600,
        height=400,
        title=(
            f"Model {selected_plot_model} - {selected_structure} "
            f"(Spearman ρ = {spearman_corr:.3f})"
        ),
    )
)
274+
275+
# Add diagonal line for perfect correlation
276+
min_energy = min(
277+
structure_df["Reference Energy"].min(), structure_df["Predicted Energy"].min()
278+
)
279+
max_energy = max(
280+
structure_df["Reference Energy"].max(), structure_df["Predicted Energy"].max()
281+
)
282+
283+
diagonal_line = (
284+
alt.Chart(
285+
pd.DataFrame({"x": [min_energy, max_energy], "y": [min_energy, max_energy]})
286+
)
287+
.mark_line(color="gray", strokeDash=[5, 5])
288+
.encode(x="x:Q", y="y:Q")
289+
)
290+
291+
# Combine scatter plot and diagonal line
292+
final_chart = scatter_chart + diagonal_line
293+
294+
st.altair_chart(final_chart, use_container_width=True)
295+
158296

159297
class ConformerSelectionPageWrapper(UIPageWrapper):
160298
"""Page wrapper for conformer selection benchmark."""

0 commit comments

Comments
 (0)