Skip to content

Commit 6f49935

Browse files
add deeplc plot to plotting module
1 parent 95e149e commit 6f49935

File tree

2 files changed

+173
-32
lines changed

2 files changed

+173
-32
lines changed

ms2rescore/report/charts.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""Collection of Plotly-based charts for reporting results of MS²Rescore."""
22

3+
import importlib.resources
34
import warnings
45
from collections import defaultdict
6+
from pathlib import Path
57
from typing import Dict, List, Optional, Tuple, Union
68

79
import mokapot
@@ -635,3 +637,155 @@ def feature_ecdf_auc_bar(
635637
},
636638
color_discrete_map=color_discrete_map,
637639
)
640+
641+
642+
def rt_scatter(
643+
df: pd.DataFrame,
644+
predicted_column: str = "Predicted retention time",
645+
observed_column: str = "Observed retention time",
646+
xaxis_label: str = "Observed retention time",
647+
yaxis_label: str = "Predicted retention time",
648+
plot_title: str = "Predicted vs. observed retention times",
649+
) -> go.Figure:
650+
"""
651+
Plot a scatter plot of the predicted vs. observed retention times.
652+
653+
Parameters
654+
----------
655+
df : pd.DataFrame
656+
Dataframe containing the predicted and observed retention times.
657+
predicted_column : str, optional
658+
Name of the column containing the predicted retention times, by default
659+
``Predicted retention time``.
660+
observed_column : str, optional
661+
Name of the column containing the observed retention times, by default
662+
``Observed retention time``.
663+
xaxis_label : str, optional
664+
X-axis label, by default ``Observed retention time``.
665+
yaxis_label : str, optional
666+
Y-axis label, by default ``Predicted retention time``.
667+
plot_title : str, optional
668+
Scatter plot title, by default ``Predicted vs. observed retention times``
669+
670+
"""
671+
# Draw scatter
672+
fig = px.scatter(
673+
df,
674+
x=observed_column,
675+
y=predicted_column,
676+
opacity=0.3,
677+
)
678+
679+
# Draw diagonal line
680+
fig.add_scatter(
681+
x=[min(df[observed_column]), max(df[observed_column])],
682+
y=[min(df[observed_column]), max(df[observed_column])],
683+
mode="lines",
684+
line=dict(color="red", width=3, dash="dash"),
685+
)
686+
687+
# Hide legend
688+
fig.update_layout(
689+
title=plot_title,
690+
showlegend=False,
691+
xaxis_title=xaxis_label,
692+
yaxis_title=yaxis_label,
693+
)
694+
695+
return fig
696+
697+
698+
def rt_distribution_baseline(
699+
df: pd.DataFrame,
700+
predicted_column: str = "Predicted retention time",
701+
observed_column: str = "Observed retention time",
702+
) -> go.Figure:
703+
"""
704+
Plot a distribution plot of the relative mean absolute error of the current
705+
DeepLC performance compared to the baseline performance.
706+
707+
Parameters
708+
----------
709+
df : pd.DataFrame
710+
Dataframe containing the predicted and observed retention times.
711+
predicted_column : str, optional
712+
Name of the column containing the predicted retention times, by default
713+
``Predicted retention time``.
714+
observed_column : str, optional
715+
Name of the column containing the observed retention times, by default
716+
``Observed retention time``.
717+
718+
"""
719+
# Get baseline data from deeplc package
720+
try:
721+
import deeplc.package_data
722+
723+
baseline_path = (
724+
Path(importlib.resources.files(deeplc.package_data))
725+
/ "baseline_performance"
726+
/ "baseline_predictions.csv"
727+
)
728+
baseline_df = pd.read_csv(baseline_path)
729+
except (ImportError, FileNotFoundError):
730+
# If deeplc is not installed or baseline data not found, return empty figure
731+
fig = go.Figure()
732+
fig.add_annotation(
733+
text="DeepLC baseline data not available. Install DeepLC to view performance comparison.",
734+
showarrow=False,
735+
)
736+
return fig
737+
738+
baseline_df["rel_mae_best"] = baseline_df[
739+
["rel_mae_transfer_learning", "rel_mae_new_model", "rel_mae_calibrate"]
740+
].min(axis=1)
741+
baseline_df.fillna(0.0, inplace=True)
742+
743+
# Calculate current RMAE and percentile compared to baseline
744+
mae = sum(abs(df[observed_column] - df[predicted_column])) / len(df.index)
745+
mae_rel = (mae / max(df[observed_column])) * 100
746+
percentile = round((baseline_df["rel_mae_transfer_learning"] < mae_rel).mean() * 100, 1)
747+
748+
# Calculate x-axis range with 5% padding
749+
all_values = np.append(baseline_df["rel_mae_transfer_learning"].values, mae_rel)
750+
padding = (all_values.max() - all_values.min()) / 20 # 5% padding
751+
x_min = all_values.min() - padding
752+
x_max = all_values.max() + padding
753+
754+
# Make labels human-readable
755+
hover_label_mapping = {
756+
"train_number": "Training dataset size",
757+
"rel_mae_transfer_learning": "RMAE with transfer learning",
758+
"rel_mae_new_model": "RMAE with new model from scratch",
759+
"rel_mae_calibrate": "RMAE with calibrating existing model",
760+
"rel_mae_best": "RMAE with best method",
761+
}
762+
label_mapping = hover_label_mapping.copy()
763+
label_mapping.update({"Unnamed: 0": "Dataset"})
764+
765+
# Generate plot
766+
fig = px.histogram(
767+
data_frame=baseline_df,
768+
x="rel_mae_best",
769+
marginal="rug",
770+
hover_data=hover_label_mapping.keys(),
771+
hover_name="Unnamed: 0",
772+
labels=label_mapping,
773+
opacity=0.8,
774+
)
775+
fig.add_vline(
776+
x=mae_rel,
777+
line_width=3,
778+
line_dash="dash",
779+
line_color="red",
780+
annotation_text=f"Current performance (percentile {percentile}%)",
781+
annotation_position="top left",
782+
name="Current performance",
783+
row=1,
784+
)
785+
fig.update_xaxes(range=[x_min, x_max])
786+
fig.update_layout(
787+
title=(f"Current DeepLC performance compared to {len(baseline_df.index)} datasets"),
788+
xaxis_title="Relative mean absolute error (%)",
789+
)
790+
791+
return fig

ms2rescore/report/generate.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -322,14 +322,12 @@ def _get_features_context(
322322

323323
# DeepLC specific charts
324324
if "deeplc" in feature_names:
325-
import deeplc.plot
326-
327-
scatter_chart = deeplc.plot.scatter(
325+
scatter_chart = charts.rt_scatter(
328326
df=features[(~psm_list["is_decoy"]) & (psm_list["qvalue"] <= 0.01)],
329327
predicted_column="predicted_retention_time_best",
330328
observed_column="observed_retention_time_best",
331329
)
332-
baseline_chart = deeplc.plot.distribution_baseline(
330+
baseline_chart = charts.rt_distribution_baseline(
333331
df=features[(~psm_list["is_decoy"]) & (psm_list["qvalue"] <= 0.01)],
334332
predicted_column="predicted_retention_time_best",
335333
observed_column="observed_retention_time_best",
@@ -345,9 +343,7 @@ def _get_features_context(
345343

346344
# IM2Deep specific charts
347345
if "im2deep" in feature_names:
348-
import deeplc.plot
349-
350-
scatter_chart = deeplc.plot.scatter(
346+
scatter_chart = charts.rt_scatter(
351347
df=features[(~psm_list["is_decoy"]) & (psm_list["qvalue"] <= 0.01)],
352348
predicted_column="ccs_predicted_im2deep",
353349
observed_column="ccs_observed_im2deep",
@@ -366,31 +362,22 @@ def _get_features_context(
366362

367363
# ionmob specific charts
368364
if "ionmob" in feature_names:
369-
try:
370-
import deeplc.plot
371-
372-
scatter_chart = deeplc.plot.scatter(
373-
df=features[(~psm_list["is_decoy"]) & (psm_list["qvalue"] <= 0.01)],
374-
predicted_column="ccs_predicted",
375-
observed_column="ccs_observed",
376-
xaxis_label="Observed CCS",
377-
yaxis_label="Predicted CCS",
378-
plot_title="Predicted vs. observed CCS - ionmob",
379-
)
380-
381-
context["charts"].append(
382-
{
383-
"title": TEXTS["charts"]["ionmob_performance"]["title"],
384-
"description": TEXTS["charts"]["ionmob_performance"]["description"],
385-
"chart": scatter_chart.to_html(**PLOTLY_HTML_KWARGS),
386-
}
387-
)
388-
389-
# TODO: for now, ionmob plot will only be available if deeplc is installed. Since ionmob does not have a dependency on deeplc, this should be changed in the future.
390-
except ImportError:
391-
logger.warning(
392-
"Could not import deeplc.plot, skipping ionmob CCS prediction performance plot. Please install DeepLC to generate this plot."
393-
)
365+
scatter_chart = charts.rt_scatter(
366+
df=features[(~psm_list["is_decoy"]) & (psm_list["qvalue"] <= 0.01)],
367+
predicted_column="ccs_predicted",
368+
observed_column="ccs_observed",
369+
xaxis_label="Observed CCS",
370+
yaxis_label="Predicted CCS",
371+
plot_title="Predicted vs. observed CCS - ionmob",
372+
)
373+
374+
context["charts"].append(
375+
{
376+
"title": TEXTS["charts"]["ionmob_performance"]["title"],
377+
"description": TEXTS["charts"]["ionmob_performance"]["description"],
378+
"chart": scatter_chart.to_html(**PLOTLY_HTML_KWARGS),
379+
}
380+
)
394381
return context
395382

396383

0 commit comments

Comments
 (0)