1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import statistics
1516from pathlib import Path
1617from typing import Callable , TypeAlias
1718
@@ -38,14 +39,32 @@ def _process_data_into_dataframe(
3839 if model_name in selected_models :
3940 model_data_converted = {
4041 "Score" : results .score ,
41- "RMSE" : results .avg_rmse ,
42- "MAE" : results .avg_mae ,
42+ "Average RMSE (kcal/mol)" : results .avg_rmse ,
43+ "Average MAE (kcal/mol)" : results .avg_mae ,
44+ "Average Spearman correlation" : statistics .mean (
45+ r .spearman_correlation for r in results .molecules
46+ ),
4347 }
4448 converted_data_scores .append (model_data_converted )
4549
4650 return pd .DataFrame (converted_data_scores , index = selected_models )
4751
4852
53+ def _molecule_stats_df (results : ConformerSelectionResult ) -> pd .DataFrame :
54+ """Return a dataframe with per-molecule stats for a benchmark result."""
55+ rows = []
56+ for m in results .molecules :
57+ rows .append ({
58+ "Molecule" : m .molecule_name ,
59+ "MAE (kcal/mol)" : float (m .mae ),
60+ "RMSE (kcal/mol)" : float (m .rmse ),
61+ "Spearman" : float (m .spearman_correlation ),
62+ "Spearman p" : float (m .spearman_p_value ),
63+ })
64+ df = pd .DataFrame (rows ).set_index ("Molecule" )
65+ return df
66+
67+
4968def conformer_selection_page (
5069 data_func : Callable [[], BenchmarkResultForMultipleModels ],
5170) -> None :
@@ -91,9 +110,7 @@ def conformer_selection_page(
91110 with col3 :
92111 create_st_image (CONFORMER_IMG_DIR / "rsz_efa00.png" , "Efavirenz" )
93112
94- st .markdown ("" )
95113 st .markdown ("## Summary statistics" )
96- st .markdown ("" )
97114
98115 # Download data and get model names
99116 if "conformer_selection_cached_data" not in st .session_state :
@@ -112,6 +129,12 @@ def conformer_selection_page(
112129 model_select = st .sidebar .multiselect (
113130 "Select model(s)" , model_names , default = model_names
114131 )
132+ # with st.sidebar.container():
133+ # selected_energy_unit = st.selectbox(
134+ # "Select an energy unit:",
135+ # ["kcal/mol", "eV"],
136+ # )
137+
115138 selected_models = model_select if model_select else model_names
116139
117140 df = _process_data_into_dataframe (data , selected_models )
@@ -128,16 +151,13 @@ def conformer_selection_page(
128151 df .reset_index ()
129152 .melt (
130153 id_vars = ["index" ],
131- value_vars = ["RMSE" , "MAE" ],
154+ value_vars = ["Average RMSE (kcal/mol) " , "Average MAE (kcal/mol) " ],
132155 var_name = "Metric" ,
133156 value_name = "Value" ,
134157 )
135158 .rename (columns = {"index" : "Model" })
136159 )
137160
138- # Capitalize metric names for better display
139- chart_df ["Metric" ] = chart_df ["Metric" ].str .upper ()
140-
141161 # Create grouped bar chart
142162 chart = (
143163 alt .Chart (chart_df )
@@ -155,6 +175,124 @@ def conformer_selection_page(
155175
156176 st .altair_chart (chart , use_container_width = True )
157177
178+ # inside conformer_selection_page, add after the existing chart display
179+ st .markdown ("## Per-molecule statistics" )
180+ st .markdown (
181+ "Per-molecule MAE, RMSE and Spearman correlation for each selected model."
182+ )
183+
184+ for model_name in selected_models :
185+ results = data .get (model_name )
186+ if results is None :
187+ continue
188+
189+ st .markdown (f"### { model_name } " )
190+ mol_df = _molecule_stats_df (results )
191+
192+ # Display table
193+ st .dataframe (mol_df .round (4 ))
194+
195+ # Error chart (MAE and RMSE)
196+ error_chart_df = mol_df .reset_index ().melt (
197+ id_vars = ["Molecule" ],
198+ value_vars = ["MAE (kcal/mol)" , "RMSE (kcal/mol)" ],
199+ var_name = "Metric" ,
200+ value_name = "Value" ,
201+ )
202+ error_chart = (
203+ alt .Chart (error_chart_df )
204+ .mark_bar ()
205+ .encode (
206+ x = alt .X ("Molecule:N" , title = "Molecule" ),
207+ y = alt .Y ("Value:Q" , title = "Error (kcal/mol)" ),
208+ color = "Metric:N" ,
209+ xOffset = "Metric:N" ,
210+ )
211+ .properties (width = 600 , height = 250 )
212+ )
213+ st .altair_chart (error_chart , use_container_width = True )
214+
215+ # Plot correlation chart for a chosen molecule and model
216+
217+ # Create selectboxes for model and structure selection
218+ col1 , col2 = st .columns (2 )
219+ with col1 :
220+ selected_plot_model = st .selectbox (
221+ "Select model for plot:" , selected_models , key = "model_selector_plot"
222+ )
223+
224+ unique_structures = list (
225+ set ([mol .molecule_name for mol in data [selected_plot_model ].molecules ])
226+ )
227+
228+ with col2 :
229+ selected_structure = st .selectbox (
230+ "Select structure for plot:" ,
231+ unique_structures ,
232+ key = "structure_selector_plot" ,
233+ )
234+
235+ model_data_for_plot = [
236+ mol
237+ for mol in data [selected_plot_model ].molecules
238+ if mol .molecule_name == selected_structure
239+ ][0 ]
240+ scatter_data = []
241+ for pred_energy , ref_energy in zip (
242+ model_data_for_plot .predicted_energy_profile ,
243+ model_data_for_plot .reference_energy_profile ,
244+ ):
245+ scatter_data .append ({
246+ "Predicted Energy" : pred_energy ,
247+ "Reference Energy" : ref_energy ,
248+ })
249+
250+ structure_df = pd .DataFrame (scatter_data )
251+
252+ spearman_corr = structure_df ["Predicted Energy" ].corr (
253+ structure_df ["Reference Energy" ], method = "spearman"
254+ )
255+
256+ # Create scatter plot
257+ scatter_chart = (
258+ alt .Chart (structure_df )
259+ .mark_circle (size = 80 , opacity = 0.7 )
260+ .encode (
261+ x = alt .X ("Reference Energy:Q" , title = "Reference Energy (kcal/mol)" ),
262+ y = alt .Y ("Predicted Energy:Q" , title = "Inferred Energy (kcal/mol)" ),
263+ tooltip = ["Reference Energy:Q" , "Energy:Q" ],
264+ )
265+ .properties (
266+ width = 600 ,
267+ height = 400 ,
268+ title = (
269+ f"Model { selected_plot_model } - { selected_structure } "
270+ f"(Spearman ρ = { spearman_corr :.3f} )"
271+ ),
272+ )
273+ )
274+
275+ # Add diagonal line for perfect correlation
276+ min_energy = min (
277+ structure_df ["Reference Energy" ].min (), structure_df ["Predicted Energy" ].min ()
278+ )
279+ max_energy = max (
280+ structure_df ["Reference Energy" ].max (), structure_df ["Predicted Energy" ].max ()
281+ )
282+
283+ diagonal_line = (
284+ alt .Chart (
285+ pd .DataFrame ({"x" : [min_energy , max_energy ], "y" : [min_energy , max_energy ]})
286+ )
287+ .mark_line (color = "gray" , strokeDash = [5 , 5 ])
288+ .encode (x = "x:Q" , y = "y:Q" )
289+ )
290+
291+ # Combine scatter plot and diagonal line
292+ final_chart = scatter_chart + diagonal_line
293+
294+ st .altair_chart (final_chart , use_container_width = True )
295+
158296
159297class ConformerSelectionPageWrapper (UIPageWrapper ):
160298 """Page wrapper for conformer selection benchmark."""
0 commit comments