10 changes: 10 additions & 0 deletions docs/source/api.rst
@@ -83,6 +83,16 @@ Base Architecture & Protocols
:undoc-members:
:show-inheritance:

.. automodule:: tab_right.base_architecture.model_comparison_protocols
:members:
:undoc-members:
:show-inheritance:

.. automodule:: tab_right.base_architecture.model_comparison_plot_protocols
:members:
:undoc-members:
:show-inheritance:

Task Detection
--------------

126 changes: 126 additions & 0 deletions tab_right/base_architecture/model_comparison_plot_protocols.py
@@ -0,0 +1,126 @@
"""Protocols for model comparison visualization in tab-right.

This module defines protocols for visualizing model comparison results. It provides
interfaces for creating visualizations that compare multiple prediction datasets
against true labels.
"""

from dataclasses import dataclass
from typing import Any, List, Optional, Protocol, Tuple, Union, runtime_checkable

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go

from .model_comparison_protocols import PredictionCalculationP

Figure = Union[go.Figure, plt.Figure]


@runtime_checkable
@dataclass
class ModelComparisonPlotP(Protocol):
"""Protocol for model comparison visualization implementations.

This protocol defines the interface for visualizing comparison results
between multiple predictions and true labels.

Parameters
----------
comparison_calc : PredictionCalculationP
An implementation of PredictionCalculationP that provides the comparison
metrics to visualize.

"""

comparison_calc: PredictionCalculationP

def plot_error_distribution(
self,
pred_data: List[pd.Series],
model_names: Optional[List[str]] = None,
figsize: Tuple[int, int] = (12, 8),
bins: int = 30,
**kwargs: Any,
) -> Figure:
"""Create a visualization comparing error distributions across models.

Parameters
----------
pred_data : List[pd.Series]
List of prediction Series to compare against the label.
model_names : Optional[List[str]], default None
Names for each model. If None, uses default names like "Model 0", "Model 1", etc.
figsize : Tuple[int, int], default (12, 8)
Figure size as (width, height) in inches.
bins : int, default 30
Number of bins for histogram visualization.
**kwargs : Any
Additional parameters for the plotting implementation.

Returns
-------
Figure
A figure object containing the error distribution comparison.

"""
...

def plot_pairwise_comparison(
self,
pred_data: List[pd.Series],
model_names: Optional[List[str]] = None,
figsize: Tuple[int, int] = (10, 8),
**kwargs: Any,
) -> Figure:
"""Create a pairwise comparison plot between models.

Parameters
----------
pred_data : List[pd.Series]
List of prediction Series to compare against the label.
model_names : Optional[List[str]], default None
Names for each model. If None, uses default names.
figsize : Tuple[int, int], default (10, 8)
Figure size as (width, height) in inches.
**kwargs : Any
Additional parameters for the plotting implementation.

Returns
-------
Figure
A figure object containing the pairwise comparison visualization.

"""
...

def plot_model_performance_summary(
self,
pred_data: List[pd.Series],
model_names: Optional[List[str]] = None,
metrics: Optional[List[str]] = None,
figsize: Tuple[int, int] = (12, 6),
**kwargs: Any,
) -> Figure:
"""Create a summary visualization of model performance metrics.

Parameters
----------
pred_data : List[pd.Series]
List of prediction Series to compare against the label.
model_names : Optional[List[str]], default None
Names for each model. If None, uses default names.
metrics : Optional[List[str]], default None
List of metrics to display. If None, uses default metrics.
figsize : Tuple[int, int], default (12, 6)
Figure size as (width, height) in inches.
**kwargs : Any
Additional parameters for the plotting implementation.

Returns
-------
Figure
A figure object containing the performance summary visualization.

"""
...
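
For reference, here is a minimal sketch of how a concrete class might satisfy `ModelComparisonPlotP`. Only `plot_error_distribution` is shown; a real implementation would also provide the pairwise and summary plots. The class name `SimpleComparisonPlot` and the matplotlib histogram approach are illustrative assumptions, not part of this PR:

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd

from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP


@dataclass
class SimpleComparisonPlot:
    """Hypothetical example of a ModelComparisonPlotP implementation (illustration only)."""

    comparison_calc: PredictionCalculationP

    def plot_error_distribution(
        self,
        pred_data: List[pd.Series],
        model_names: Optional[List[str]] = None,
        figsize: Tuple[int, int] = (12, 8),
        bins: int = 30,
        **kwargs: Any,
    ) -> plt.Figure:
        # Delegate the error computation to the comparison calculator,
        # then overlay one histogram of pointwise errors per model.
        errors = self.comparison_calc(pred_data)
        names = model_names or [f"Model {i}" for i in range(len(pred_data))]
        fig, ax = plt.subplots(figsize=figsize)
        for i, name in enumerate(names):
            ax.hist(errors[f"pred_{i}_error"], bins=bins, alpha=0.5, label=name)
        ax.set_xlabel("Pointwise error")
        ax.set_ylabel("Count")
        ax.legend()
        return fig
```

Because the protocol delegates error computation to `comparison_calc`, the plotting code stays independent of how errors are defined.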

61 changes: 61 additions & 0 deletions tab_right/base_architecture/model_comparison_protocols.py
@@ -0,0 +1,61 @@
"""Protocol definitions for model comparison analysis in tab-right.

This module defines protocol classes for model comparison analysis,
including interfaces for prediction calculation and comparison between multiple models.
"""

from dataclasses import dataclass
from typing import Callable, List, Optional, Protocol, runtime_checkable

import pandas as pd


@runtime_checkable
@dataclass
class PredictionCalculationP(Protocol):
"""Protocol for prediction calculation implementations.

This protocol defines the interface for calculating pointwise errors
between multiple sets of predictions and true labels.

Parameters
----------
df : pd.DataFrame
A DataFrame containing the data for analysis.
label_col : str
Column name for the true target values.

"""

df: pd.DataFrame
label_col: str

def __call__(
self,
pred_data: List[pd.Series],
error_func: Optional[Callable[[pd.Series, pd.Series], pd.Series]] = None,
) -> pd.DataFrame:
"""Calculate pointwise errors for multiple prediction series against the label.

Parameters
----------
pred_data : List[pd.Series]
List of prediction Series to compare against the label.
Each Series should have the same index as the DataFrame.
error_func : Optional[Callable[[pd.Series, pd.Series], pd.Series]], default None
Function for calculating pairwise error between predictions and labels.
If None, defaults to a standard metric (e.g., absolute error for regression,
0/1 loss for classification).
Function signature: error_func(y_true, y_pred) -> pd.Series

Returns
-------
pd.DataFrame
DataFrame containing pointwise errors for each prediction series.
Expected columns:
- Original DataFrame columns (as context)
- `{label_col}`: The true label values
- `pred_0_error`, `pred_1_error`, ...: Pointwise errors for each prediction series
- `model_id`: Optional identifier for the prediction series

"""
1 change: 1 addition & 0 deletions tests/base_architecture/model_comparison/__init__.py
@@ -0,0 +1 @@
"""Tests for model comparison protocols."""
83 changes: 83 additions & 0 deletions tests/base_architecture/model_comparison/model_comparison_check.py
@@ -0,0 +1,83 @@
"""Compliance checks for the `PredictionCalculationP` protocol."""

import pandas as pd

from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP

from ..base_protocols_check import CheckProtocols


class CheckPredictionCalculation(CheckProtocols):
"""Class for checking compliance of `PredictionCalculationP` protocol."""

# Use the protocol type directly
class_to_check = PredictionCalculationP

def test_attributes(self, instance_to_check: PredictionCalculationP) -> None:
"""Test attributes of the instance to ensure compliance."""
assert hasattr(instance_to_check, "df")
assert hasattr(instance_to_check, "label_col")
assert isinstance(instance_to_check.df, pd.DataFrame)
assert isinstance(instance_to_check.label_col, str)
assert instance_to_check.label_col in instance_to_check.df.columns

def test_call_method(self, instance_to_check: PredictionCalculationP) -> None:
"""Test the __call__ method of the instance."""
# Create test prediction data
n_samples = len(instance_to_check.df)
pred_data = [
pd.Series(range(n_samples), index=instance_to_check.df.index, name="pred_0"),
pd.Series(range(n_samples, 2 * n_samples), index=instance_to_check.df.index, name="pred_1"),
]

# Test with default error function
result = instance_to_check(pred_data)
assert isinstance(result, pd.DataFrame)

# Check that the result contains original DataFrame columns
for col in instance_to_check.df.columns:
assert col in result.columns

# Check that error columns are present
assert "pred_0_error" in result.columns
assert "pred_1_error" in result.columns

# Test with custom error function
def custom_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
return abs(y_true - y_pred)

result_custom = instance_to_check(pred_data, error_func=custom_error)
assert isinstance(result_custom, pd.DataFrame)
assert "pred_0_error" in result_custom.columns
assert "pred_1_error" in result_custom.columns

# Test with single prediction
single_pred = [pred_data[0]]
result_single = instance_to_check(single_pred)
assert isinstance(result_single, pd.DataFrame)
assert "pred_0_error" in result_single.columns
assert "pred_1_error" not in result_single.columns

def test_error_calculations(self, instance_to_check: PredictionCalculationP) -> None:
"""Test that error calculations work correctly."""
        # Create test data where the expected errors are known.
        label_values = instance_to_check.df[instance_to_check.label_col]

# Create prediction that should have zero error
pred_exact = pd.Series(label_values.values, index=instance_to_check.df.index, name="pred_exact")

# Create prediction with constant offset
pred_offset = pd.Series(label_values.values + 1, index=instance_to_check.df.index, name="pred_offset")

pred_data = [pred_exact, pred_offset]

# Test with default error function
result = instance_to_check(pred_data)

        # Check that the exact prediction has (numerically) zero error
        exact_errors = result["pred_0_error"]
        assert (exact_errors.abs() < 1e-10).all()

        # Check that the offset prediction has strictly positive error
offset_errors = result["pred_1_error"]
assert (offset_errors > 0).all()
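
Assuming `CheckProtocols` drives these checks through an `instance_to_check` pytest fixture supplied by concrete test classes (the base class is not shown in this diff), a concrete test module might wire an implementation in roughly as follows. `_FakePredictionCalculation` and `TestFakePredictionCalculation` are hypothetical names used only for illustration:

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

import pandas as pd
import pytest

from .model_comparison_check import CheckPredictionCalculation


@dataclass
class _FakePredictionCalculation:
    """Hypothetical minimal implementation used only to exercise the checks."""

    df: pd.DataFrame
    label_col: str

    def __call__(
        self,
        pred_data: List[pd.Series],
        error_func: Optional[Callable[[pd.Series, pd.Series], pd.Series]] = None,
    ) -> pd.DataFrame:
        if error_func is None:
            def _default_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
                return (y_true - y_pred).abs()

            error_func = _default_error
        result = self.df.copy()
        for i, pred in enumerate(pred_data):
            result[f"pred_{i}_error"] = error_func(self.df[self.label_col], pred)
        return result


class TestFakePredictionCalculation(CheckPredictionCalculation):
    """Run the protocol-compliance checks against the fake implementation."""

    @pytest.fixture
    def instance_to_check(self) -> _FakePredictionCalculation:
        # Small numeric frame so the default absolute-error checks are meaningful.
        df = pd.DataFrame({"feature": [0.1, 0.2, 0.3], "target": [1.0, 2.0, 3.0]})
        return _FakePredictionCalculation(df=df, label_col="target")
```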
@@ -0,0 +1,49 @@
"""Compliance checks for the `ModelComparisonPlotP` protocol."""

from tab_right.base_architecture.model_comparison_plot_protocols import ModelComparisonPlotP

from ..base_protocols_check import CheckProtocols


class CheckModelComparisonPlot(CheckProtocols):
"""Class for checking compliance of `ModelComparisonPlotP` protocol."""

# Use the protocol type directly
class_to_check = ModelComparisonPlotP

def test_attributes(self, instance_to_check: ModelComparisonPlotP) -> None:
"""Test attributes of the instance to ensure compliance."""
assert hasattr(instance_to_check, "comparison_calc")
# The comparison_calc should implement PredictionCalculationP protocol
assert hasattr(instance_to_check.comparison_calc, "df")
assert hasattr(instance_to_check.comparison_calc, "label_col")
assert hasattr(instance_to_check.comparison_calc, "__call__")

def test_plot_error_distribution(self, instance_to_check: ModelComparisonPlotP) -> None:
"""Test the plot_error_distribution method."""
        # Check that the method exists and is callable. The actual plotting is
        # not exercised here: these checks only verify protocol compliance, and
        # concrete implementations are tested separately.
        assert hasattr(instance_to_check, "plot_error_distribution")
        assert callable(instance_to_check.plot_error_distribution)

def test_plot_pairwise_comparison(self, instance_to_check: ModelComparisonPlotP) -> None:
"""Test the plot_pairwise_comparison method."""
# Test method exists and is callable
assert hasattr(instance_to_check, "plot_pairwise_comparison")
assert callable(instance_to_check.plot_pairwise_comparison)

def test_plot_model_performance_summary(self, instance_to_check: ModelComparisonPlotP) -> None:
"""Test the plot_model_performance_summary method."""
# Test method exists and is callable
assert hasattr(instance_to_check, "plot_model_performance_summary")
assert callable(instance_to_check.plot_model_performance_summary)
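
Under the same fixture assumption, the plot checks could be exercised with a stub plotter along these lines. All names here are hypothetical, and the import of `CheckModelComparisonPlot` assumes a module path that is not visible in this diff:

```python
from dataclasses import dataclass
from typing import Any, List

import matplotlib.pyplot as plt
import pandas as pd
import pytest

# Module name assumed; the filename of the plot check module is not shown in this diff.
from .model_comparison_plot_check import CheckModelComparisonPlot


@dataclass
class _FakeCalc:
    """Hypothetical calculator exposing the attributes the checks look for."""

    df: pd.DataFrame
    label_col: str

    def __call__(self, pred_data: List[pd.Series], error_func: Any = None) -> pd.DataFrame:
        result = self.df.copy()
        for i, pred in enumerate(pred_data):
            result[f"pred_{i}_error"] = (self.df[self.label_col] - pred).abs()
        return result


@dataclass
class _FakePlotter:
    """Hypothetical plotter returning empty figures; a real one would follow the full protocol signatures."""

    comparison_calc: _FakeCalc

    def plot_error_distribution(self, pred_data: List[pd.Series], **kwargs: Any) -> plt.Figure:
        return plt.figure()

    def plot_pairwise_comparison(self, pred_data: List[pd.Series], **kwargs: Any) -> plt.Figure:
        return plt.figure()

    def plot_model_performance_summary(self, pred_data: List[pd.Series], **kwargs: Any) -> plt.Figure:
        return plt.figure()


class TestFakePlotter(CheckModelComparisonPlot):
    """Run the plot-protocol compliance checks against the stub plotter."""

    @pytest.fixture
    def instance_to_check(self) -> _FakePlotter:
        df = pd.DataFrame({"feature": [0.1, 0.2, 0.3], "target": [1.0, 2.0, 3.0]})
        return _FakePlotter(comparison_calc=_FakeCalc(df=df, label_col="target"))
```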