|
1 | 1 | """Collection of Plotly-based charts for reporting results of MS²Rescore.""" |
2 | 2 |
|
| 3 | +import importlib.resources |
3 | 4 | import warnings |
4 | 5 | from collections import defaultdict |
| 6 | +from pathlib import Path |
5 | 7 | from typing import Dict, List, Optional, Tuple, Union |
6 | 8 |
|
7 | 9 | import mokapot |
@@ -635,3 +637,155 @@ def feature_ecdf_auc_bar( |
635 | 637 | }, |
636 | 638 | color_discrete_map=color_discrete_map, |
637 | 639 | ) |
| 640 | + |
| 641 | + |
def rt_scatter(
    df: pd.DataFrame,
    predicted_column: str = "Predicted retention time",
    observed_column: str = "Observed retention time",
    xaxis_label: str = "Observed retention time",
    yaxis_label: str = "Predicted retention time",
    plot_title: str = "Predicted vs. observed retention times",
) -> go.Figure:
    """
    Plot a scatter plot of the predicted vs. observed retention times.

    A dashed red identity line (predicted == observed) is drawn on top of the
    scatter, spanning the range of the observed retention times.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing the predicted and observed retention times.
    predicted_column : str, optional
        Name of the column containing the predicted retention times, by default
        ``Predicted retention time``.
    observed_column : str, optional
        Name of the column containing the observed retention times, by default
        ``Observed retention time``.
    xaxis_label : str, optional
        X-axis label, by default ``Observed retention time``.
    yaxis_label : str, optional
        Y-axis label, by default ``Predicted retention time``.
    plot_title : str, optional
        Scatter plot title, by default ``Predicted vs. observed retention times``

    Returns
    -------
    go.Figure
        Figure with the scatter trace and the identity line.

    """
    # Draw scatter
    fig = px.scatter(
        df,
        x=observed_column,
        y=predicted_column,
        opacity=0.3,
    )

    # Draw diagonal (identity) line over the observed range. ``Series.min`` and
    # ``Series.max`` skip missing values and return NaN for an empty column,
    # whereas the builtin ``min``/``max`` raise on empty input and give
    # position-dependent results when NaN is present.
    observed = df[observed_column]
    diagonal = [observed.min(), observed.max()]
    fig.add_scatter(
        x=diagonal,
        y=diagonal,
        mode="lines",
        line=dict(color="red", width=3, dash="dash"),
    )

    # Set title and axis labels; hide the legend
    fig.update_layout(
        title=plot_title,
        showlegend=False,
        xaxis_title=xaxis_label,
        yaxis_title=yaxis_label,
    )

    return fig
| 696 | + |
| 697 | + |
def _rt_baseline_message_figure(message: str) -> go.Figure:
    """Return an otherwise empty figure that only shows ``message`` as an annotation."""
    fig = go.Figure()
    fig.add_annotation(text=message, showarrow=False)
    return fig


def rt_distribution_baseline(
    df: pd.DataFrame,
    predicted_column: str = "Predicted retention time",
    observed_column: str = "Observed retention time",
) -> go.Figure:
    """
    Plot a distribution plot of the relative mean absolute error of the current
    DeepLC performance compared to the baseline performance.

    The baseline is read from ``baseline_performance/baseline_predictions.csv``
    inside ``deeplc.package_data``. The current relative MAE is the mean absolute
    difference between observed and predicted retention times, divided by the
    maximum observed retention time and expressed as a percentage.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing the predicted and observed retention times.
    predicted_column : str, optional
        Name of the column containing the predicted retention times, by default
        ``Predicted retention time``.
    observed_column : str, optional
        Name of the column containing the observed retention times, by default
        ``Observed retention time``.

    Returns
    -------
    go.Figure
        Histogram of the best baseline relative MAE per dataset with a vertical
        line marking the current performance. If DeepLC or its baseline file is
        unavailable, or ``df`` holds no usable retention times, a figure that
        only contains an explanatory annotation is returned instead.

    """
    # Get baseline data from deeplc package
    try:
        import deeplc.package_data

        # Stay within the Traversable API and let ``as_file`` provide a real
        # filesystem path; wrapping ``files()`` in ``Path`` only works when the
        # package data happens to live directly on disk as a regular package.
        baseline_resource = (
            importlib.resources.files(deeplc.package_data)
            / "baseline_performance"
            / "baseline_predictions.csv"
        )
        with importlib.resources.as_file(baseline_resource) as baseline_path:
            baseline_df = pd.read_csv(baseline_path)
    except (ImportError, FileNotFoundError):
        # If deeplc is not installed or baseline data not found, return empty figure
        return _rt_baseline_message_figure(
            "DeepLC baseline data not available. Install DeepLC to view performance comparison."
        )

    # Best (lowest) relative MAE per baseline dataset across the three methods
    baseline_df["rel_mae_best"] = baseline_df[
        ["rel_mae_transfer_learning", "rel_mae_new_model", "rel_mae_calibrate"]
    ].min(axis=1)
    baseline_df = baseline_df.fillna(0.0)

    # Calculate current RMAE. The vectorised mean/max skip missing values, so a
    # few NaN rows no longer turn the whole metric into NaN.
    abs_errors = (df[observed_column] - df[predicted_column]).abs()
    max_observed = df[observed_column].max()
    if abs_errors.count() == 0 or not max_observed > 0:
        # Empty/all-NaN input or a non-positive maximum: the relative MAE is
        # undefined, so fall back to an annotated figure instead of crashing.
        return _rt_baseline_message_figure(
            "No valid retention time data available to compare with the DeepLC baseline."
        )
    mae = abs_errors.mean()
    mae_rel = (mae / max_observed) * 100

    # Percentile of the current RMAE compared to baseline.
    # NOTE(review): the percentile is computed against the transfer-learning
    # baseline while the histogram shows ``rel_mae_best`` - confirm this is the
    # intended comparison.
    percentile = round((baseline_df["rel_mae_transfer_learning"] < mae_rel).mean() * 100, 1)

    # Calculate x-axis range with 5% padding from the values that are actually
    # drawn (histogram column plus the current-performance marker), so that no
    # histogram bar falls outside the visible range.
    all_values = np.append(baseline_df["rel_mae_best"].to_numpy(), mae_rel)
    padding = (all_values.max() - all_values.min()) / 20  # 5% padding
    x_min = all_values.min() - padding
    x_max = all_values.max() + padding

    # Make labels human-readable
    hover_label_mapping = {
        "train_number": "Training dataset size",
        "rel_mae_transfer_learning": "RMAE with transfer learning",
        "rel_mae_new_model": "RMAE with new model from scratch",
        "rel_mae_calibrate": "RMAE with calibrating existing model",
        "rel_mae_best": "RMAE with best method",
    }
    # "Unnamed: 0" is the unnamed first CSV column, used as the hover name below
    label_mapping = {**hover_label_mapping, "Unnamed: 0": "Dataset"}

    # Generate plot
    fig = px.histogram(
        data_frame=baseline_df,
        x="rel_mae_best",
        marginal="rug",
        hover_data=list(hover_label_mapping),
        hover_name="Unnamed: 0",
        labels=label_mapping,
        opacity=0.8,
    )
    fig.add_vline(
        x=mae_rel,
        line_width=3,
        line_dash="dash",
        line_color="red",
        annotation_text=f"Current performance (percentile {percentile}%)",
        annotation_position="top left",
        name="Current performance",
        row=1,
    )
    fig.update_xaxes(range=[x_min, x_max])
    fig.update_layout(
        title=(f"Current DeepLC performance compared to {len(baseline_df.index)} datasets"),
        xaxis_title="Relative mean absolute error (%)",
    )

    return fig
0 commit comments