diff --git a/docs/source/api.rst b/docs/source/api.rst
index 8e76213..62876c5 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -83,6 +83,16 @@ Base Architecture & Protocols
    :undoc-members:
    :show-inheritance:
 
+.. automodule:: tab_right.base_architecture.model_comparison_protocols
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. automodule:: tab_right.base_architecture.model_comparison_plot_protocols
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 Task Detection
 --------------
 
diff --git a/tab_right/base_architecture/model_comparison_plot_protocols.py b/tab_right/base_architecture/model_comparison_plot_protocols.py
new file mode 100644
index 0000000..bab96bd
--- /dev/null
+++ b/tab_right/base_architecture/model_comparison_plot_protocols.py
@@ -0,0 +1,123 @@
+"""Protocols for model comparison visualization in tab-right.
+
+This module defines protocols for visualizing model comparison results. It provides
+interfaces for creating visualizations that compare multiple prediction datasets
+against true labels.
+"""
+
+from dataclasses import dataclass
+from typing import Any, List, Optional, Protocol, Tuple, Union, runtime_checkable
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import plotly.graph_objects as go
+
+from .model_comparison_protocols import PredictionCalculationP
+
+Figure = Union[go.Figure, plt.Figure]
+
+
+@runtime_checkable
+@dataclass
+class ModelComparisonPlotP(Protocol):
+    """Protocol for model comparison visualization implementations.
+
+    This protocol defines the interface for visualizing comparison results
+    between multiple predictions and true labels.
+
+    Parameters
+    ----------
+    comparison_calc : PredictionCalculationP
+        An implementation of PredictionCalculationP that provides the comparison
+        metrics to visualize.
+
+    """
+
+    comparison_calc: PredictionCalculationP
+
+    def plot_error_distribution(
+        self,
+        pred_data: List[pd.Series],
+        model_names: Optional[List[str]] = None,
+        figsize: Tuple[int, int] = (12, 8),
+        bins: int = 30,
+        **kwargs: Any,
+    ) -> Figure:
+        """Create a visualization comparing error distributions across models.
+
+        Parameters
+        ----------
+        pred_data : List[pd.Series]
+            List of prediction Series to compare against the label.
+        model_names : Optional[List[str]], default None
+            Names for each model. If None, uses default names like "Model 0", "Model 1", etc.
+        figsize : Tuple[int, int], default (12, 8)
+            Figure size as (width, height) in inches.
+        bins : int, default 30
+            Number of bins for histogram visualization.
+        **kwargs : Any
+            Additional parameters for the plotting implementation.
+
+        Returns
+        -------
+        Figure
+            A figure object containing the error distribution comparison.
+
+        """
+
+    def plot_pairwise_comparison(
+        self,
+        pred_data: List[pd.Series],
+        model_names: Optional[List[str]] = None,
+        figsize: Tuple[int, int] = (10, 8),
+        **kwargs: Any,
+    ) -> Figure:
+        """Create a pairwise comparison plot between models.
+
+        Parameters
+        ----------
+        pred_data : List[pd.Series]
+            List of prediction Series to compare against the label.
+        model_names : Optional[List[str]], default None
+            Names for each model. If None, uses default names.
+        figsize : Tuple[int, int], default (10, 8)
+            Figure size as (width, height) in inches.
+        **kwargs : Any
+            Additional parameters for the plotting implementation.
+
+        Returns
+        -------
+        Figure
+            A figure object containing the pairwise comparison visualization.
+
+        """
+
+    def plot_model_performance_summary(
+        self,
+        pred_data: List[pd.Series],
+        model_names: Optional[List[str]] = None,
+        metrics: Optional[List[str]] = None,
+        figsize: Tuple[int, int] = (12, 6),
+        **kwargs: Any,
+    ) -> Figure:
+        """Create a summary visualization of model performance metrics.
+
+        Parameters
+        ----------
+        pred_data : List[pd.Series]
+            List of prediction Series to compare against the label.
+        model_names : Optional[List[str]], default None
+            Names for each model. If None, uses default names.
+        metrics : Optional[List[str]], default None
+            List of metrics to display. If None, uses default metrics.
+        figsize : Tuple[int, int], default (12, 6)
+            Figure size as (width, height) in inches.
+        **kwargs : Any
+            Additional parameters for the plotting implementation.
+
+        Returns
+        -------
+        Figure
+            A figure object containing the performance summary visualization.
+
+        """
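
For orientation, a minimal sketch of how `ModelComparisonPlotP` could be satisfied is shown below. It is illustrative only and not part of this diff: the class name `MatplotlibComparisonPlot`, the overlaid-histogram layout, and the choice to forward `**kwargs` to `Axes.hist` are assumptions. Only `plot_error_distribution` is fleshed out; a conforming implementation would also provide `plot_pairwise_comparison` and `plot_model_performance_summary`.

from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd

from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP


@dataclass
class MatplotlibComparisonPlot:
    """Hypothetical sketch of a matplotlib-backed ModelComparisonPlotP implementation."""

    comparison_calc: PredictionCalculationP

    def plot_error_distribution(
        self,
        pred_data: List[pd.Series],
        model_names: Optional[List[str]] = None,
        figsize: Tuple[int, int] = (12, 8),
        bins: int = 30,
        **kwargs: Any,
    ) -> plt.Figure:
        # Delegate the pointwise error computation to the calculation protocol.
        errors = self.comparison_calc(pred_data)
        names = model_names or [f"Model {i}" for i in range(len(pred_data))]
        fig, ax = plt.subplots(figsize=figsize)
        for i, name in enumerate(names):
            # One overlaid histogram per prediction series; extra kwargs go to Axes.hist.
            ax.hist(errors[f"pred_{i}_error"], bins=bins, alpha=0.5, label=name, **kwargs)
        ax.set_xlabel("Pointwise error")
        ax.set_ylabel("Count")
        ax.legend()
        return fig

    # plot_pairwise_comparison and plot_model_performance_summary are omitted in this
    # sketch; a real implementation must provide them as well.

Delegating the error computation to `comparison_calc` keeps the plotting layer free of metric logic, which is the separation the two protocol modules encode.
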
+ + """ + + def plot_model_performance_summary( + self, + pred_data: List[pd.Series], + model_names: Optional[List[str]] = None, + metrics: Optional[List[str]] = None, + figsize: Tuple[int, int] = (12, 6), + **kwargs: Any, + ) -> Figure: + """Create a summary visualization of model performance metrics. + + Parameters + ---------- + pred_data : List[pd.Series] + List of prediction Series to compare against the label. + model_names : Optional[List[str]], default None + Names for each model. If None, uses default names. + metrics : Optional[List[str]], default None + List of metrics to display. If None, uses default metrics. + figsize : Tuple[int, int], default (12, 6) + Figure size as (width, height) in inches. + **kwargs : Any + Additional parameters for the plotting implementation. + + Returns + ------- + Figure + A figure object containing the performance summary visualization. + + """ diff --git a/tab_right/base_architecture/model_comparison_protocols.py b/tab_right/base_architecture/model_comparison_protocols.py new file mode 100644 index 0000000..c6fd9f4 --- /dev/null +++ b/tab_right/base_architecture/model_comparison_protocols.py @@ -0,0 +1,61 @@ +"""Protocol definitions for model comparison analysis in tab-right. + +This module defines protocol classes for model comparison analysis, +including interfaces for prediction calculation and comparison between multiple models. +""" + +from dataclasses import dataclass +from typing import Callable, List, Optional, Protocol, runtime_checkable + +import pandas as pd + + +@runtime_checkable +@dataclass +class PredictionCalculationP(Protocol): + """Protocol for prediction calculation implementations. + + This protocol defines the interface for calculating pointwise errors + between multiple sets of predictions and true labels. + + Parameters + ---------- + df : pd.DataFrame + A DataFrame containing the data for analysis. + label_col : str + Column name for the true target values. + + """ + + df: pd.DataFrame + label_col: str + + def __call__( + self, + pred_data: List[pd.Series], + error_func: Optional[Callable[[pd.Series, pd.Series], pd.Series]] = None, + ) -> pd.DataFrame: + """Calculate pointwise errors for multiple prediction series against the label. + + Parameters + ---------- + pred_data : List[pd.Series] + List of prediction Series to compare against the label. + Each Series should have the same index as the DataFrame. + error_func : Optional[Callable[[pd.Series, pd.Series], pd.Series]], default None + Function for calculating pairwise error between predictions and labels. + If None, defaults to a standard metric (e.g., absolute error for regression, + 0/1 loss for classification). + Function signature: error_func(y_true, y_pred) -> pd.Series + + Returns + ------- + pd.DataFrame + DataFrame containing pointwise errors for each prediction series. 
diff --git a/tests/base_architecture/model_comparison/__init__.py b/tests/base_architecture/model_comparison/__init__.py
new file mode 100644
index 0000000..451c1fa
--- /dev/null
+++ b/tests/base_architecture/model_comparison/__init__.py
@@ -0,0 +1 @@
+"""Tests for model comparison protocols."""
diff --git a/tests/base_architecture/model_comparison/model_comparison_check.py b/tests/base_architecture/model_comparison/model_comparison_check.py
new file mode 100644
index 0000000..c8b5fe8
--- /dev/null
+++ b/tests/base_architecture/model_comparison/model_comparison_check.py
@@ -0,0 +1,84 @@
+"""Compliance checks for the `PredictionCalculationP` protocol."""
+
+import pandas as pd
+
+from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP
+
+from ..base_protocols_check import CheckProtocols
+
+
+class CheckPredictionCalculation(CheckProtocols):
+    """Class for checking compliance of `PredictionCalculationP` protocol."""
+
+    # Use the protocol type directly
+    class_to_check = PredictionCalculationP
+
+    def test_attributes(self, instance_to_check: PredictionCalculationP) -> None:
+        """Test attributes of the instance to ensure compliance."""
+        assert hasattr(instance_to_check, "df")
+        assert hasattr(instance_to_check, "label_col")
+        assert isinstance(instance_to_check.df, pd.DataFrame)
+        assert isinstance(instance_to_check.label_col, str)
+        assert instance_to_check.label_col in instance_to_check.df.columns
+
+    def test_call_method(self, instance_to_check: PredictionCalculationP) -> None:
+        """Test the __call__ method of the instance."""
+        # Create test prediction data
+        n_samples = len(instance_to_check.df)
+        pred_data = [
+            pd.Series(range(n_samples), index=instance_to_check.df.index, name="pred_0"),
+            pd.Series(range(n_samples, 2 * n_samples), index=instance_to_check.df.index, name="pred_1"),
+        ]
+
+        # Test with default error function
+        result = instance_to_check(pred_data)
+        assert isinstance(result, pd.DataFrame)
+
+        # Check that the result contains original DataFrame columns
+        for col in instance_to_check.df.columns:
+            assert col in result.columns
+
+        # Check that error columns are present
+        assert "pred_0_error" in result.columns
+        assert "pred_1_error" in result.columns
+
+        # Test with custom error function
+        def custom_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
+            return abs(y_true - y_pred)
+
+        result_custom = instance_to_check(pred_data, error_func=custom_error)
+        assert isinstance(result_custom, pd.DataFrame)
+        assert "pred_0_error" in result_custom.columns
+        assert "pred_1_error" in result_custom.columns
+
+        # Test with single prediction
+        single_pred = [pred_data[0]]
+        result_single = instance_to_check(single_pred)
+        assert isinstance(result_single, pd.DataFrame)
+        assert "pred_0_error" in result_single.columns
+        assert "pred_1_error" not in result_single.columns
+
+    def test_error_calculations(self, instance_to_check: PredictionCalculationP) -> None:
+        """Test that error calculations work correctly."""
+        # Create test data where we know the expected errors
+        label_values = instance_to_check.df[instance_to_check.label_col]
+
+        # Create prediction that should have zero error
+        pred_exact: pd.Series = pd.Series(
+            label_values.to_numpy(), index=instance_to_check.df.index, name="pred_exact"
+        )
+
+        # Create prediction with a constant offset
+        pred_offset: pd.Series = pd.Series(
+            label_values.to_numpy() + 1, index=instance_to_check.df.index, name="pred_offset"
+        )
+
+        pred_data = [pred_exact, pred_offset]
+
+        # Test with default error function
+        result = instance_to_check(pred_data)
+
+        # Check that exact prediction has zero error (or very close to zero)
+        exact_errors = result["pred_0_error"]
+        assert (exact_errors < 1e-10).all() or (exact_errors == 0).all()
+
+        # Check that offset prediction has consistent non-zero error
+        offset_errors = result["pred_1_error"]
+        assert (offset_errors > 0).all()
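
Under pytest's default file-naming rules a `*_check.py` module is presumably not collected on its own; a concrete test module is expected to subclass the check class and supply the object under test. A hypothetical wiring is sketched below, assuming the `instance_to_check` pytest fixture is how `CheckProtocols` subclasses receive that object and that the file would live next to `model_comparison_check.py` (e.g. as `test_model_comparison.py`). The `_AbsoluteErrorCalculation` stub is defined inline so the sketch is self-contained.

from dataclasses import dataclass
from typing import Callable, List, Optional

import pandas as pd
import pytest

from .model_comparison_check import CheckPredictionCalculation


@dataclass
class _AbsoluteErrorCalculation:
    """Minimal in-test implementation used only to exercise the compliance checks."""

    df: pd.DataFrame
    label_col: str

    def __call__(
        self,
        pred_data: List[pd.Series],
        error_func: Optional[Callable[[pd.Series, pd.Series], pd.Series]] = None,
    ) -> pd.DataFrame:
        # Default to absolute error; keep the original columns as context.
        err = error_func or (lambda y_true, y_pred: (y_true - y_pred).abs())
        result = self.df.copy()
        for i, y_pred in enumerate(pred_data):
            result[f"pred_{i}_error"] = err(self.df[self.label_col], y_pred)
        return result


class TestAbsoluteErrorCalculation(CheckPredictionCalculation):
    """Run the PredictionCalculationP compliance checks against the stub above."""

    @pytest.fixture
    def instance_to_check(self) -> _AbsoluteErrorCalculation:
        df = pd.DataFrame({"feature": [0.1, 0.2, 0.3, 0.4], "target": [1.0, 2.0, 3.0, 4.0]})
        return _AbsoluteErrorCalculation(df=df, label_col="target")
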
diff --git a/tests/base_architecture/model_comparison/model_comparison_plot_check.py b/tests/base_architecture/model_comparison/model_comparison_plot_check.py
new file mode 100644
index 0000000..7987df6
--- /dev/null
+++ b/tests/base_architecture/model_comparison/model_comparison_plot_check.py
@@ -0,0 +1,49 @@
+"""Compliance checks for the `ModelComparisonPlotP` protocol."""
+
+from tab_right.base_architecture.model_comparison_plot_protocols import ModelComparisonPlotP
+
+from ..base_protocols_check import CheckProtocols
+
+
+class CheckModelComparisonPlot(CheckProtocols):
+    """Class for checking compliance of `ModelComparisonPlotP` protocol."""
+
+    # Use the protocol type directly
+    class_to_check = ModelComparisonPlotP
+
+    def test_attributes(self, instance_to_check: ModelComparisonPlotP) -> None:
+        """Test attributes of the instance to ensure compliance."""
+        assert hasattr(instance_to_check, "comparison_calc")
+        # The comparison_calc should implement the PredictionCalculationP protocol
+        assert hasattr(instance_to_check.comparison_calc, "df")
+        assert hasattr(instance_to_check.comparison_calc, "label_col")
+        assert callable(instance_to_check.comparison_calc)
+
+    def test_plot_error_distribution(self, instance_to_check: ModelComparisonPlotP) -> None:
+        """Test the plot_error_distribution method."""
+        # Test method exists and is callable
+        assert hasattr(instance_to_check, "plot_error_distribution")
+        assert callable(instance_to_check.plot_error_distribution)
+
+        # Note: The actual plotting behaviour is not exercised here since these are
+        # protocol compliance checks; concrete implementations are tested separately.
+
+    def test_plot_pairwise_comparison(self, instance_to_check: ModelComparisonPlotP) -> None:
+        """Test the plot_pairwise_comparison method."""
+        # Test method exists and is callable
+        assert hasattr(instance_to_check, "plot_pairwise_comparison")
+        assert callable(instance_to_check.plot_pairwise_comparison)
+
+    def test_plot_model_performance_summary(self, instance_to_check: ModelComparisonPlotP) -> None:
+        """Test the plot_model_performance_summary method."""
+        # Test method exists and is callable
+        assert hasattr(instance_to_check, "plot_model_performance_summary")
+        assert callable(instance_to_check.plot_model_performance_summary)
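
A parallel wiring sketch for the plot checks. Because these checks only verify that the expected attributes and method names exist, a lightweight stub is enough to exercise them; the stub classes, test class, and module placement below are hypothetical and rest on the same `instance_to_check` fixture assumption as the previous sketch.

from dataclasses import dataclass
from typing import Any, List

import matplotlib.pyplot as plt
import pandas as pd
import pytest

from .model_comparison_plot_check import CheckModelComparisonPlot


@dataclass
class _StubCalc:
    """Bare-bones stand-in for a PredictionCalculationP implementation."""

    df: pd.DataFrame
    label_col: str

    def __call__(self, pred_data: List[pd.Series], **kwargs: Any) -> pd.DataFrame:
        return self.df


@dataclass
class _StubPlot:
    """Stand-in exposing the three ModelComparisonPlotP plotting methods."""

    comparison_calc: _StubCalc

    def plot_error_distribution(self, pred_data: List[pd.Series], **kwargs: Any) -> plt.Figure:
        return plt.figure()

    def plot_pairwise_comparison(self, pred_data: List[pd.Series], **kwargs: Any) -> plt.Figure:
        return plt.figure()

    def plot_model_performance_summary(self, pred_data: List[pd.Series], **kwargs: Any) -> plt.Figure:
        return plt.figure()


class TestStubComparisonPlot(CheckModelComparisonPlot):
    """Run the ModelComparisonPlotP compliance checks against the stub above."""

    @pytest.fixture
    def instance_to_check(self) -> _StubPlot:
        df = pd.DataFrame({"target": [1.0, 2.0, 3.0]})
        return _StubPlot(comparison_calc=_StubCalc(df=df, label_col="target"))
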