10 changes: 10 additions & 0 deletions docs/source/api.rst
@@ -83,6 +83,16 @@ Base Architecture & Protocols
:undoc-members:
:show-inheritance:

.. automodule:: tab_right.base_architecture.model_comparison_protocols
:members:
:undoc-members:
:show-inheritance:

.. automodule:: tab_right.base_architecture.model_comparison_plot_protocols
:members:
:undoc-members:
:show-inheritance:

Task Detection
--------------

123 changes: 123 additions & 0 deletions tab_right/base_architecture/model_comparison_plot_protocols.py
@@ -0,0 +1,123 @@
"""Protocols for model comparison visualization in tab-right.

This module defines protocols for visualizing model comparison results. It provides
interfaces for creating visualizations that compare multiple prediction datasets
against true labels.
"""

from dataclasses import dataclass
from typing import Any, List, Optional, Protocol, Tuple, Union, runtime_checkable

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go

from .model_comparison_protocols import PredictionCalculationP

Figure = Union[go.Figure, plt.Figure]


@runtime_checkable
@dataclass
class ModelComparisonPlotP(Protocol):
"""Protocol for model comparison visualization implementations.

This protocol defines the interface for visualizing comparison results
between multiple predictions and true labels.

Parameters
----------
comparison_calc : PredictionCalculationP
An implementation of PredictionCalculationP that provides the comparison
metrics to visualize.

"""

comparison_calc: PredictionCalculationP

def plot_error_distribution(
self,
pred_data: List[pd.Series],
model_names: Optional[List[str]] = None,
figsize: Tuple[int, int] = (12, 8),
bins: int = 30,
**kwargs: Any,
) -> Figure:
"""Create a visualization comparing error distributions across models.

Parameters
----------
pred_data : List[pd.Series]
List of prediction Series to compare against the label.
model_names : Optional[List[str]], default None
Names for each model. If None, uses default names like "Model 0", "Model 1", etc.
figsize : Tuple[int, int], default (12, 8)
Figure size as (width, height) in inches.
bins : int, default 30
Number of bins for histogram visualization.
**kwargs : Any
Additional parameters for the plotting implementation.

Returns
-------
Figure
A figure object containing the error distribution comparison.

"""

def plot_pairwise_comparison(
self,
pred_data: List[pd.Series],
model_names: Optional[List[str]] = None,
figsize: Tuple[int, int] = (10, 8),
**kwargs: Any,
) -> Figure:
"""Create a pairwise comparison plot between models.

Parameters
----------
pred_data : List[pd.Series]
List of prediction Series to compare against the label.
model_names : Optional[List[str]], default None
Names for each model. If None, uses default names.
figsize : Tuple[int, int], default (10, 8)
Figure size as (width, height) in inches.
**kwargs : Any
Additional parameters for the plotting implementation.

Returns
-------
Figure
A figure object containing the pairwise comparison visualization.

"""

def plot_model_performance_summary(
self,
pred_data: List[pd.Series],
model_names: Optional[List[str]] = None,
metrics: Optional[List[str]] = None,
figsize: Tuple[int, int] = (12, 6),
**kwargs: Any,
) -> Figure:
"""Create a summary visualization of model performance metrics.

Parameters
----------
pred_data : List[pd.Series]
List of prediction Series to compare against the label.
model_names : Optional[List[str]], default None
Names for each model. If None, uses default names.
metrics : Optional[List[str]], default None
List of metrics to display. If None, uses default metrics.
figsize : Tuple[int, int], default (12, 6)
Figure size as (width, height) in inches.
**kwargs : Any
Additional parameters for the plotting implementation.

Returns
-------
Figure
A figure object containing the performance summary visualization.

"""
61 changes: 61 additions & 0 deletions tab_right/base_architecture/model_comparison_protocols.py
@@ -0,0 +1,61 @@
"""Protocol definitions for model comparison analysis in tab-right.

This module defines protocol classes for model comparison analysis,
including interfaces for prediction calculation and comparison between multiple models.
"""

from dataclasses import dataclass
from typing import Callable, List, Optional, Protocol, runtime_checkable

import pandas as pd


@runtime_checkable
@dataclass
class PredictionCalculationP(Protocol):
"""Protocol for prediction calculation implementations.

This protocol defines the interface for calculating pointwise errors
between multiple sets of predictions and true labels.

Parameters
----------
df : pd.DataFrame
A DataFrame containing the data for analysis.
label_col : str
Column name for the true target values.

"""

df: pd.DataFrame
label_col: str

def __call__(
self,
pred_data: List[pd.Series],
error_func: Optional[Callable[[pd.Series, pd.Series], pd.Series]] = None,
) -> pd.DataFrame:
"""Calculate pointwise errors for multiple prediction series against the label.

Parameters
----------
pred_data : List[pd.Series]
List of prediction Series to compare against the label.
Each Series should have the same index as the DataFrame.
error_func : Optional[Callable[[pd.Series, pd.Series], pd.Series]], default None
Function for calculating pairwise error between predictions and labels.
If None, defaults to a standard metric (e.g., absolute error for regression,
0/1 loss for classification).
Function signature: error_func(y_true, y_pred) -> pd.Series

Returns
-------
pd.DataFrame
DataFrame containing pointwise errors for each prediction series.
Expected columns:
- Original DataFrame columns (as context)
- `{label_col}`: The true label values
- `pred_0_error`, `pred_1_error`, ...: Pointwise errors for each prediction series
- `model_id`: Optional identifier for the prediction series

"""
1 change: 1 addition & 0 deletions tests/base_architecture/model_comparison/__init__.py
@@ -0,0 +1 @@
"""Tests for model comparison protocols."""
84 changes: 84 additions & 0 deletions tests/base_architecture/model_comparison/model_comparison_check.py
@@ -0,0 +1,84 @@
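"""Compliance checks for the `PredictionCalculationP` protocol."""
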
import pandas as pd

from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP

from ..base_protocols_check import CheckProtocols


class CheckPredictionCalculation(CheckProtocols):
"""Class for checking compliance of `PredictionCalculationP` protocol."""

# Use the protocol type directly
class_to_check = PredictionCalculationP

def test_attributes(self, instance_to_check: PredictionCalculationP) -> None:
"""Test attributes of the instance to ensure compliance."""
assert hasattr(instance_to_check, "df")
assert hasattr(instance_to_check, "label_col")
assert isinstance(instance_to_check.df, pd.DataFrame)
assert isinstance(instance_to_check.label_col, str)
assert instance_to_check.label_col in instance_to_check.df.columns

def test_call_method(self, instance_to_check: PredictionCalculationP) -> None:
"""Test the __call__ method of the instance."""
# Create test prediction data
n_samples = len(instance_to_check.df)
pred_data = [
pd.Series(range(n_samples), index=instance_to_check.df.index, name="pred_0"),
pd.Series(range(n_samples, 2 * n_samples), index=instance_to_check.df.index, name="pred_1"),
]

# Test with default error function
result = instance_to_check(pred_data)
assert isinstance(result, pd.DataFrame)

# Check that the result contains original DataFrame columns
for col in instance_to_check.df.columns:
assert col in result.columns

# Check that error columns are present
assert "pred_0_error" in result.columns
assert "pred_1_error" in result.columns

# Test with custom error function
def custom_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
return abs(y_true - y_pred)

result_custom = instance_to_check(pred_data, error_func=custom_error)
assert isinstance(result_custom, pd.DataFrame)
assert "pred_0_error" in result_custom.columns
assert "pred_1_error" in result_custom.columns

# Test with single prediction
single_pred = [pred_data[0]]
result_single = instance_to_check(single_pred)
assert isinstance(result_single, pd.DataFrame)
assert "pred_0_error" in result_single.columns
assert "pred_1_error" not in result_single.columns

def test_error_calculations(self, instance_to_check: PredictionCalculationP) -> None:
"""Test that error calculations work correctly."""
# Create test data where we know the expected errors
label_values = instance_to_check.df[instance_to_check.label_col]

# Create prediction that should have zero error
pred_exact: pd.Series = pd.Series(label_values.values, index=instance_to_check.df.index, name="pred_exact")

# Create prediction with constant offset
pred_offset: pd.Series = pd.Series(
label_values.to_numpy() + 1, index=instance_to_check.df.index, name="pred_offset"
)

pred_data = [pred_exact, pred_offset]

# Test with default error function
result = instance_to_check(pred_data)

# Check that exact prediction has zero error (or very close to zero)
exact_errors = result["pred_0_error"]
assert (exact_errors < 1e-10).all() or (exact_errors == 0).all()

# Check that offset prediction has consistent non-zero error
offset_errors = result["pred_1_error"]
assert (offset_errors > 0).all()
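One way these checks could be wired into an actual test module is sketched below. The `instance_to_check` fixture name is inferred from the method signatures above (the real fixture mechanism lives in `CheckProtocols`, which is not shown in this diff), and the toy implementation simply reuses the absolute-error logic from the earlier sketch.

from dataclasses import dataclass
from typing import Callable, List, Optional

import pandas as pd
import pytest

from .model_comparison_check import CheckPredictionCalculation


@dataclass
class _ToyPredictionCalculation:
    """Minimal in-test implementation used only to exercise the checks."""

    df: pd.DataFrame
    label_col: str

    def __call__(
        self,
        pred_data: List[pd.Series],
        error_func: Optional[Callable[[pd.Series, pd.Series], pd.Series]] = None,
    ) -> pd.DataFrame:
        if error_func is None:
            def error_func(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
                return (y_true - y_pred).abs()

        out = self.df.copy()
        for i, pred in enumerate(pred_data):
            out[f"pred_{i}_error"] = error_func(self.df[self.label_col], pred)
        return out


class TestPredictionCalculation(CheckPredictionCalculation):
    """Runs the generic compliance checks against the toy implementation."""

    @pytest.fixture
    def instance_to_check(self) -> _ToyPredictionCalculation:
        df = pd.DataFrame({"feature": [0.5, 1.5, 2.5], "target": [1.0, 2.0, 3.0]})
        return _ToyPredictionCalculation(df=df, label_col="target")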
@@ -0,0 +1,49 @@
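"""Compliance checks for the `ModelComparisonPlotP` protocol."""
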
import pandas as pd

from tab_right.base_architecture.model_comparison_plot_protocols import ModelComparisonPlotP

from ..base_protocols_check import CheckProtocols


class CheckModelComparisonPlot(CheckProtocols):
"""Class for checking compliance of `ModelComparisonPlotP` protocol."""

# Use the protocol type directly
class_to_check = ModelComparisonPlotP

def test_attributes(self, instance_to_check: ModelComparisonPlotP) -> None:
"""Test attributes of the instance to ensure compliance."""
assert hasattr(instance_to_check, "comparison_calc")
# The comparison_calc should implement PredictionCalculationP protocol
assert hasattr(instance_to_check.comparison_calc, "df")
assert hasattr(instance_to_check.comparison_calc, "label_col")
assert callable(instance_to_check.comparison_calc)

def test_plot_error_distribution(self, instance_to_check: ModelComparisonPlotP) -> None:
"""Test the plot_error_distribution method."""

# Test method exists and is callable
assert hasattr(instance_to_check, "plot_error_distribution")
assert callable(instance_to_check.plot_error_distribution)

# Test that method can be called with default parameters
# Note: We're not testing the actual plotting since these are protocols
# The implementation would be tested separately

def test_plot_pairwise_comparison(self, instance_to_check: ModelComparisonPlotP) -> None:
"""Test the plot_pairwise_comparison method."""
# Test method exists and is callable
assert hasattr(instance_to_check, "plot_pairwise_comparison")
assert callable(instance_to_check.plot_pairwise_comparison)

def test_plot_model_performance_summary(self, instance_to_check: ModelComparisonPlotP) -> None:
"""Test the plot_model_performance_summary method."""
# Test method exists and is callable
assert hasattr(instance_to_check, "plot_model_performance_summary")
assert callable(instance_to_check.plot_model_performance_summary)
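To illustrate why the check above asserts that `comparison_calc` is callable, here is a rough sketch of a plotly-backed `plot_pairwise_comparison` that consumes it. This is not part of the PR: the scatter-of-errors layout, the inches-to-pixels conversion, and the restriction to the first pair of models (at least two prediction series are assumed) are all illustrative choices.

from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

import pandas as pd
import plotly.graph_objects as go

from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP


@dataclass
class PlotlyModelComparisonPlot:
    """Hypothetical plotly-backed implementation (one method shown)."""

    comparison_calc: PredictionCalculationP

    def plot_pairwise_comparison(
        self,
        pred_data: List[pd.Series],
        model_names: Optional[List[str]] = None,
        figsize: Tuple[int, int] = (10, 8),
        **kwargs: Any,
    ) -> go.Figure:
        # The comparison calculator is called here, which is what
        # CheckModelComparisonPlot.test_attributes verifies above.
        errors = self.comparison_calc(pred_data)
        names = model_names or [f"Model {i}" for i in range(len(pred_data))]
        fig = go.Figure()
        fig.add_trace(
            go.Scatter(
                x=errors["pred_0_error"],
                y=errors["pred_1_error"],
                mode="markers",
                name=f"{names[0]} vs {names[1]}",
            )
        )
        # Rough inches-to-pixels conversion at 100 dpi.
        fig.update_layout(
            width=figsize[0] * 100,
            height=figsize[1] * 100,
            xaxis_title=f"{names[0]} error",
            yaxis_title=f"{names[1]} error",
        )
        return fig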