From 85cdb5187ed7ebeae009223b6d7ae14231a9f06b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Jun 2025 14:50:55 +0000 Subject: [PATCH 1/5] Initial plan for issue From 1e9ea4958a24cbd16ba11eb0fe9b6560b2abccd0 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Jun 2025 15:00:32 +0000 Subject: [PATCH 2/5] Implement model comparison protocol classes and tests Co-authored-by: eh-main-bot <171766998+eh-main-bot@users.noreply.github.com> --- docs/source/api.rst | 10 ++ .../model_comparison_plot_protocols.py | 126 ++++++++++++++++++ .../model_comparison_protocols.py | 61 +++++++++ .../model_comparison/__init__.py | 1 + .../model_comparison_check.py | 83 ++++++++++++ .../model_comparison_plot_check.py | 49 +++++++ .../test_model_comparison_plot_protocol.py | 97 ++++++++++++++ .../test_model_comparison_protocol.py | 53 ++++++++ 8 files changed, 480 insertions(+) create mode 100644 tab_right/base_architecture/model_comparison_plot_protocols.py create mode 100644 tab_right/base_architecture/model_comparison_protocols.py create mode 100644 tests/base_architecture/model_comparison/__init__.py create mode 100644 tests/base_architecture/model_comparison/model_comparison_check.py create mode 100644 tests/base_architecture/model_comparison/model_comparison_plot_check.py create mode 100644 tests/base_architecture/model_comparison/test_model_comparison_plot_protocol.py create mode 100644 tests/base_architecture/model_comparison/test_model_comparison_protocol.py diff --git a/docs/source/api.rst b/docs/source/api.rst index 8e76213..62876c5 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -83,6 +83,16 @@ Base Architecture & Protocols :undoc-members: :show-inheritance: +.. automodule:: tab_right.base_architecture.model_comparison_protocols + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: tab_right.base_architecture.model_comparison_plot_protocols + :members: + :undoc-members: + :show-inheritance: + Task Detection -------------- diff --git a/tab_right/base_architecture/model_comparison_plot_protocols.py b/tab_right/base_architecture/model_comparison_plot_protocols.py new file mode 100644 index 0000000..4154277 --- /dev/null +++ b/tab_right/base_architecture/model_comparison_plot_protocols.py @@ -0,0 +1,126 @@ +"""Protocols for model comparison visualization in tab-right. + +This module defines protocols for visualizing model comparison results. It provides +interfaces for creating visualizations that compare multiple prediction datasets +against true labels. +""" + +from dataclasses import dataclass +from typing import Any, List, Optional, Protocol, Tuple, Union, runtime_checkable + +import matplotlib.pyplot as plt +import pandas as pd +import plotly.graph_objects as go + +from .model_comparison_protocols import PredictionCalculationP + +Figure = Union[go.Figure, plt.Figure] + + +@runtime_checkable +@dataclass +class ModelComparisonPlotP(Protocol): + """Protocol for model comparison visualization implementations. + + This protocol defines the interface for visualizing comparison results + between multiple predictions and true labels. + + Parameters + ---------- + comparison_calc : PredictionCalculationP + An implementation of PredictionCalculationP that provides the comparison + metrics to visualize. 
+ + """ + + comparison_calc: PredictionCalculationP + + def plot_error_distribution( + self, + pred_data: List[pd.Series], + model_names: Optional[List[str]] = None, + figsize: Tuple[int, int] = (12, 8), + bins: int = 30, + **kwargs: Any, + ) -> Figure: + """Create a visualization comparing error distributions across models. + + Parameters + ---------- + pred_data : List[pd.Series] + List of prediction Series to compare against the label. + model_names : Optional[List[str]], default None + Names for each model. If None, uses default names like "Model 0", "Model 1", etc. + figsize : Tuple[int, int], default (12, 8) + Figure size as (width, height) in inches. + bins : int, default 30 + Number of bins for histogram visualization. + **kwargs : Any + Additional parameters for the plotting implementation. + + Returns + ------- + Figure + A figure object containing the error distribution comparison. + + """ + ... + + def plot_pairwise_comparison( + self, + pred_data: List[pd.Series], + model_names: Optional[List[str]] = None, + figsize: Tuple[int, int] = (10, 8), + **kwargs: Any, + ) -> Figure: + """Create a pairwise comparison plot between models. + + Parameters + ---------- + pred_data : List[pd.Series] + List of prediction Series to compare against the label. + model_names : Optional[List[str]], default None + Names for each model. If None, uses default names. + figsize : Tuple[int, int], default (10, 8) + Figure size as (width, height) in inches. + **kwargs : Any + Additional parameters for the plotting implementation. + + Returns + ------- + Figure + A figure object containing the pairwise comparison visualization. + + """ + ... + + def plot_model_performance_summary( + self, + pred_data: List[pd.Series], + model_names: Optional[List[str]] = None, + metrics: Optional[List[str]] = None, + figsize: Tuple[int, int] = (12, 6), + **kwargs: Any, + ) -> Figure: + """Create a summary visualization of model performance metrics. + + Parameters + ---------- + pred_data : List[pd.Series] + List of prediction Series to compare against the label. + model_names : Optional[List[str]], default None + Names for each model. If None, uses default names. + metrics : Optional[List[str]], default None + List of metrics to display. If None, uses default metrics. + figsize : Tuple[int, int], default (12, 6) + Figure size as (width, height) in inches. + **kwargs : Any + Additional parameters for the plotting implementation. + + Returns + ------- + Figure + A figure object containing the performance summary visualization. + + """ + ... diff --git a/tab_right/base_architecture/model_comparison_protocols.py b/tab_right/base_architecture/model_comparison_protocols.py new file mode 100644 index 0000000..c6fd9f4 --- /dev/null +++ b/tab_right/base_architecture/model_comparison_protocols.py @@ -0,0 +1,61 @@ +"""Protocol definitions for model comparison analysis in tab-right. + +This module defines protocol classes for model comparison analysis, +including interfaces for prediction calculation and comparison between multiple models. +""" + +from dataclasses import dataclass +from typing import Callable, List, Optional, Protocol, runtime_checkable + +import pandas as pd + + +@runtime_checkable +@dataclass +class PredictionCalculationP(Protocol): + """Protocol for prediction calculation implementations. + + This protocol defines the interface for calculating pointwise errors + between multiple sets of predictions and true labels. + + Parameters + ---------- + df : pd.DataFrame + A DataFrame containing the data for analysis. 
+ label_col : str + Column name for the true target values. + + """ + + df: pd.DataFrame + label_col: str + + def __call__( + self, + pred_data: List[pd.Series], + error_func: Optional[Callable[[pd.Series, pd.Series], pd.Series]] = None, + ) -> pd.DataFrame: + """Calculate pointwise errors for multiple prediction series against the label. + + Parameters + ---------- + pred_data : List[pd.Series] + List of prediction Series to compare against the label. + Each Series should have the same index as the DataFrame. + error_func : Optional[Callable[[pd.Series, pd.Series], pd.Series]], default None + Function for calculating pairwise error between predictions and labels. + If None, defaults to a standard metric (e.g., absolute error for regression, + 0/1 loss for classification). + Function signature: error_func(y_true, y_pred) -> pd.Series + + Returns + ------- + pd.DataFrame + DataFrame containing pointwise errors for each prediction series. + Expected columns: + - Original DataFrame columns (as context) + - `{label_col}`: The true label values + - `pred_0_error`, `pred_1_error`, ...: Pointwise errors for each prediction series + - `model_id`: Optional identifier for the prediction series + + """ diff --git a/tests/base_architecture/model_comparison/__init__.py b/tests/base_architecture/model_comparison/__init__.py new file mode 100644 index 0000000..1ab3659 --- /dev/null +++ b/tests/base_architecture/model_comparison/__init__.py @@ -0,0 +1 @@ +"""Tests for model comparison protocols.""" \ No newline at end of file diff --git a/tests/base_architecture/model_comparison/model_comparison_check.py b/tests/base_architecture/model_comparison/model_comparison_check.py new file mode 100644 index 0000000..6810f4b --- /dev/null +++ b/tests/base_architecture/model_comparison/model_comparison_check.py @@ -0,0 +1,83 @@ +import pandas as pd +import pytest + +from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP + +from ..base_protocols_check import CheckProtocols + + +class CheckPredictionCalculation(CheckProtocols): + """Class for checking compliance of `PredictionCalculationP` protocol.""" + + # Use the protocol type directly + class_to_check = PredictionCalculationP + + def test_attributes(self, instance_to_check: PredictionCalculationP) -> None: + """Test attributes of the instance to ensure compliance.""" + assert hasattr(instance_to_check, "df") + assert hasattr(instance_to_check, "label_col") + assert isinstance(instance_to_check.df, pd.DataFrame) + assert isinstance(instance_to_check.label_col, str) + assert instance_to_check.label_col in instance_to_check.df.columns + + def test_call_method(self, instance_to_check: PredictionCalculationP) -> None: + """Test the __call__ method of the instance.""" + # Create test prediction data + n_samples = len(instance_to_check.df) + pred_data = [ + pd.Series(range(n_samples), index=instance_to_check.df.index, name="pred_0"), + pd.Series(range(n_samples, 2 * n_samples), index=instance_to_check.df.index, name="pred_1"), + ] + + # Test with default error function + result = instance_to_check(pred_data) + assert isinstance(result, pd.DataFrame) + + # Check that the result contains original DataFrame columns + for col in instance_to_check.df.columns: + assert col in result.columns + + # Check that error columns are present + assert "pred_0_error" in result.columns + assert "pred_1_error" in result.columns + + # Test with custom error function + def custom_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series: + return abs(y_true 
- y_pred) + + result_custom = instance_to_check(pred_data, error_func=custom_error) + assert isinstance(result_custom, pd.DataFrame) + assert "pred_0_error" in result_custom.columns + assert "pred_1_error" in result_custom.columns + + # Test with single prediction + single_pred = [pred_data[0]] + result_single = instance_to_check(single_pred) + assert isinstance(result_single, pd.DataFrame) + assert "pred_0_error" in result_single.columns + assert "pred_1_error" not in result_single.columns + + def test_error_calculations(self, instance_to_check: PredictionCalculationP) -> None: + """Test that error calculations work correctly.""" + # Create test data where we know the expected errors + n_samples = len(instance_to_check.df) + label_values = instance_to_check.df[instance_to_check.label_col] + + # Create prediction that should have zero error + pred_exact = pd.Series(label_values.values, index=instance_to_check.df.index, name="pred_exact") + + # Create prediction with constant offset + pred_offset = pd.Series(label_values.values + 1, index=instance_to_check.df.index, name="pred_offset") + + pred_data = [pred_exact, pred_offset] + + # Test with default error function + result = instance_to_check(pred_data) + + # Check that exact prediction has zero error (or very close to zero) + exact_errors = result["pred_0_error"] + assert (exact_errors < 1e-10).all() or (exact_errors == 0).all() + + # Check that offset prediction has consistent non-zero error + offset_errors = result["pred_1_error"] + assert (offset_errors > 0).all() \ No newline at end of file diff --git a/tests/base_architecture/model_comparison/model_comparison_plot_check.py b/tests/base_architecture/model_comparison/model_comparison_plot_check.py new file mode 100644 index 0000000..8506166 --- /dev/null +++ b/tests/base_architecture/model_comparison/model_comparison_plot_check.py @@ -0,0 +1,49 @@ +import pandas as pd + +from tab_right.base_architecture.model_comparison_plot_protocols import ModelComparisonPlotP + +from ..base_protocols_check import CheckProtocols + + +class CheckModelComparisonPlot(CheckProtocols): + """Class for checking compliance of `ModelComparisonPlotP` protocol.""" + + # Use the protocol type directly + class_to_check = ModelComparisonPlotP + + def test_attributes(self, instance_to_check: ModelComparisonPlotP) -> None: + """Test attributes of the instance to ensure compliance.""" + assert hasattr(instance_to_check, "comparison_calc") + # The comparison_calc should implement PredictionCalculationP protocol + assert hasattr(instance_to_check.comparison_calc, "df") + assert hasattr(instance_to_check.comparison_calc, "label_col") + assert hasattr(instance_to_check.comparison_calc, "__call__") + + def test_plot_error_distribution(self, instance_to_check: ModelComparisonPlotP) -> None: + """Test the plot_error_distribution method.""" + # Create test prediction data + n_samples = len(instance_to_check.comparison_calc.df) + pred_data = [ + pd.Series(range(n_samples), index=instance_to_check.comparison_calc.df.index, name="pred_0"), + pd.Series(range(n_samples, 2 * n_samples), index=instance_to_check.comparison_calc.df.index, name="pred_1"), + ] + + # Test method exists and is callable + assert hasattr(instance_to_check, "plot_error_distribution") + assert callable(instance_to_check.plot_error_distribution) + + # Test that method can be called with default parameters + # Note: We're not testing the actual plotting since these are protocols + # The implementation would be tested separately + + def 
test_plot_pairwise_comparison(self, instance_to_check: ModelComparisonPlotP) -> None: + """Test the plot_pairwise_comparison method.""" + # Test method exists and is callable + assert hasattr(instance_to_check, "plot_pairwise_comparison") + assert callable(instance_to_check.plot_pairwise_comparison) + + def test_plot_model_performance_summary(self, instance_to_check: ModelComparisonPlotP) -> None: + """Test the plot_model_performance_summary method.""" + # Test method exists and is callable + assert hasattr(instance_to_check, "plot_model_performance_summary") + assert callable(instance_to_check.plot_model_performance_summary) \ No newline at end of file diff --git a/tests/base_architecture/model_comparison/test_model_comparison_plot_protocol.py b/tests/base_architecture/model_comparison/test_model_comparison_plot_protocol.py new file mode 100644 index 0000000..2aa3487 --- /dev/null +++ b/tests/base_architecture/model_comparison/test_model_comparison_plot_protocol.py @@ -0,0 +1,97 @@ +import numpy as np +import pandas as pd +import pytest +from typing import List, Optional, Any, Tuple +from dataclasses import dataclass +from matplotlib.figure import Figure as MatplotlibFigure +from plotly.graph_objects import Figure as PlotlyFigure + +from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP +from tab_right.base_architecture.model_comparison_plot_protocols import ModelComparisonPlotP + +from .model_comparison_plot_check import CheckModelComparisonPlot + + +@dataclass +class DummyPredictionCalculation: + """Dummy implementation of PredictionCalculationP for testing.""" + + df: pd.DataFrame + label_col: str + + def __call__(self, pred_data, error_func=None): + """Simple implementation that calculates pointwise errors.""" + result_df = self.df.copy() + + # Default error function (absolute error) + if error_func is None: + def error_func(y_true, y_pred): + return abs(y_true - y_pred) + + # Get label values + y_true = self.df[self.label_col] + + # Calculate errors for each prediction + for i, pred_series in enumerate(pred_data): + error_col = f"pred_{i}_error" + result_df[error_col] = error_func(y_true, pred_series) + + return result_df + + +@dataclass +class DummyModelComparisonPlot: + """Dummy implementation of ModelComparisonPlotP for testing.""" + + comparison_calc: PredictionCalculationP + + def plot_error_distribution( + self, + pred_data: List[pd.Series], + model_names: Optional[List[str]] = None, + figsize: Tuple[int, int] = (12, 8), + bins: int = 30, + **kwargs: Any, + ): + """Dummy implementation - returns None for protocol testing.""" + return None + + def plot_pairwise_comparison( + self, + pred_data: List[pd.Series], + model_names: Optional[List[str]] = None, + figsize: Tuple[int, int] = (10, 8), + **kwargs: Any, + ): + """Dummy implementation - returns None for protocol testing.""" + return None + + def plot_model_performance_summary( + self, + pred_data: List[pd.Series], + model_names: Optional[List[str]] = None, + metrics: Optional[List[str]] = None, + figsize: Tuple[int, int] = (12, 6), + **kwargs: Any, + ): + """Dummy implementation - returns None for protocol testing.""" + return None + + +@pytest.fixture +def instance_to_check() -> ModelComparisonPlotP: + """Provides an instance of ModelComparisonPlotP for protocol testing.""" + # Create a simple test DataFrame + df = pd.DataFrame({ + "feature1": np.random.rand(20), + "feature2": np.random.choice(["A", "B", "C"], 20), + "label": np.random.rand(20) + }) + comparison_calc = 
DummyPredictionCalculation(df, "label") + return DummyModelComparisonPlot(comparison_calc) + + +class TestModelComparisonPlot(CheckModelComparisonPlot): + """Test class for `ModelComparisonPlotP` protocol compliance.""" + + # Note: instance_to_check pytest fixture is implemented above \ No newline at end of file diff --git a/tests/base_architecture/model_comparison/test_model_comparison_protocol.py b/tests/base_architecture/model_comparison/test_model_comparison_protocol.py new file mode 100644 index 0000000..9cc336e --- /dev/null +++ b/tests/base_architecture/model_comparison/test_model_comparison_protocol.py @@ -0,0 +1,53 @@ +import numpy as np +import pandas as pd +import pytest +from dataclasses import dataclass + +from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP + +from .model_comparison_check import CheckPredictionCalculation + + +@dataclass +class DummyPredictionCalculation: + """Dummy implementation of PredictionCalculationP for testing.""" + + df: pd.DataFrame + label_col: str + + def __call__(self, pred_data, error_func=None): + """Simple implementation that calculates pointwise errors.""" + result_df = self.df.copy() + + # Default error function (absolute error) + if error_func is None: + def error_func(y_true, y_pred): + return abs(y_true - y_pred) + + # Get label values + y_true = self.df[self.label_col] + + # Calculate errors for each prediction + for i, pred_series in enumerate(pred_data): + error_col = f"pred_{i}_error" + result_df[error_col] = error_func(y_true, pred_series) + + return result_df + + +@pytest.fixture +def instance_to_check() -> PredictionCalculationP: + """Provides an instance of PredictionCalculationP for protocol testing.""" + # Create a simple test DataFrame + df = pd.DataFrame({ + "feature1": np.random.rand(20), + "feature2": np.random.choice(["A", "B", "C"], 20), + "label": np.random.rand(20) + }) + return DummyPredictionCalculation(df, "label") + + +class TestPredictionCalculation(CheckPredictionCalculation): + """Test class for `PredictionCalculationP` protocol compliance.""" + + # Note: instance_to_check pytest fixture is implemented above \ No newline at end of file From 69bc74a8dd6bd09876283379cfcded4eeff4a5e4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Jun 2025 15:34:49 +0000 Subject: [PATCH 3/5] Remove test files and fix protocol method style consistency Co-authored-by: eh-main-bot <171766998+eh-main-bot@users.noreply.github.com> --- .../model_comparison_plot_protocols.py | 3 - .../test_model_comparison_plot_protocol.py | 97 ------------------- .../test_model_comparison_protocol.py | 53 ---------- 3 files changed, 153 deletions(-) delete mode 100644 tests/base_architecture/model_comparison/test_model_comparison_plot_protocol.py delete mode 100644 tests/base_architecture/model_comparison/test_model_comparison_protocol.py diff --git a/tab_right/base_architecture/model_comparison_plot_protocols.py b/tab_right/base_architecture/model_comparison_plot_protocols.py index 4154277..bab96bd 100644 --- a/tab_right/base_architecture/model_comparison_plot_protocols.py +++ b/tab_right/base_architecture/model_comparison_plot_protocols.py @@ -64,7 +64,6 @@ def plot_error_distribution( A figure object containing the error distribution comparison. """ - ... def plot_pairwise_comparison( self, @@ -92,7 +91,6 @@ def plot_pairwise_comparison( A figure object containing the pairwise comparison visualization. """ - ... 
def plot_model_performance_summary( self, @@ -123,4 +121,3 @@ def plot_model_performance_summary( A figure object containing the performance summary visualization. """ - ... diff --git a/tests/base_architecture/model_comparison/test_model_comparison_plot_protocol.py b/tests/base_architecture/model_comparison/test_model_comparison_plot_protocol.py deleted file mode 100644 index 2aa3487..0000000 --- a/tests/base_architecture/model_comparison/test_model_comparison_plot_protocol.py +++ /dev/null @@ -1,97 +0,0 @@ -import numpy as np -import pandas as pd -import pytest -from typing import List, Optional, Any, Tuple -from dataclasses import dataclass -from matplotlib.figure import Figure as MatplotlibFigure -from plotly.graph_objects import Figure as PlotlyFigure - -from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP -from tab_right.base_architecture.model_comparison_plot_protocols import ModelComparisonPlotP - -from .model_comparison_plot_check import CheckModelComparisonPlot - - -@dataclass -class DummyPredictionCalculation: - """Dummy implementation of PredictionCalculationP for testing.""" - - df: pd.DataFrame - label_col: str - - def __call__(self, pred_data, error_func=None): - """Simple implementation that calculates pointwise errors.""" - result_df = self.df.copy() - - # Default error function (absolute error) - if error_func is None: - def error_func(y_true, y_pred): - return abs(y_true - y_pred) - - # Get label values - y_true = self.df[self.label_col] - - # Calculate errors for each prediction - for i, pred_series in enumerate(pred_data): - error_col = f"pred_{i}_error" - result_df[error_col] = error_func(y_true, pred_series) - - return result_df - - -@dataclass -class DummyModelComparisonPlot: - """Dummy implementation of ModelComparisonPlotP for testing.""" - - comparison_calc: PredictionCalculationP - - def plot_error_distribution( - self, - pred_data: List[pd.Series], - model_names: Optional[List[str]] = None, - figsize: Tuple[int, int] = (12, 8), - bins: int = 30, - **kwargs: Any, - ): - """Dummy implementation - returns None for protocol testing.""" - return None - - def plot_pairwise_comparison( - self, - pred_data: List[pd.Series], - model_names: Optional[List[str]] = None, - figsize: Tuple[int, int] = (10, 8), - **kwargs: Any, - ): - """Dummy implementation - returns None for protocol testing.""" - return None - - def plot_model_performance_summary( - self, - pred_data: List[pd.Series], - model_names: Optional[List[str]] = None, - metrics: Optional[List[str]] = None, - figsize: Tuple[int, int] = (12, 6), - **kwargs: Any, - ): - """Dummy implementation - returns None for protocol testing.""" - return None - - -@pytest.fixture -def instance_to_check() -> ModelComparisonPlotP: - """Provides an instance of ModelComparisonPlotP for protocol testing.""" - # Create a simple test DataFrame - df = pd.DataFrame({ - "feature1": np.random.rand(20), - "feature2": np.random.choice(["A", "B", "C"], 20), - "label": np.random.rand(20) - }) - comparison_calc = DummyPredictionCalculation(df, "label") - return DummyModelComparisonPlot(comparison_calc) - - -class TestModelComparisonPlot(CheckModelComparisonPlot): - """Test class for `ModelComparisonPlotP` protocol compliance.""" - - # Note: instance_to_check pytest fixture is implemented above \ No newline at end of file diff --git a/tests/base_architecture/model_comparison/test_model_comparison_protocol.py b/tests/base_architecture/model_comparison/test_model_comparison_protocol.py deleted file mode 100644 
index 9cc336e..0000000 --- a/tests/base_architecture/model_comparison/test_model_comparison_protocol.py +++ /dev/null @@ -1,53 +0,0 @@ -import numpy as np -import pandas as pd -import pytest -from dataclasses import dataclass - -from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP - -from .model_comparison_check import CheckPredictionCalculation - - -@dataclass -class DummyPredictionCalculation: - """Dummy implementation of PredictionCalculationP for testing.""" - - df: pd.DataFrame - label_col: str - - def __call__(self, pred_data, error_func=None): - """Simple implementation that calculates pointwise errors.""" - result_df = self.df.copy() - - # Default error function (absolute error) - if error_func is None: - def error_func(y_true, y_pred): - return abs(y_true - y_pred) - - # Get label values - y_true = self.df[self.label_col] - - # Calculate errors for each prediction - for i, pred_series in enumerate(pred_data): - error_col = f"pred_{i}_error" - result_df[error_col] = error_func(y_true, pred_series) - - return result_df - - -@pytest.fixture -def instance_to_check() -> PredictionCalculationP: - """Provides an instance of PredictionCalculationP for protocol testing.""" - # Create a simple test DataFrame - df = pd.DataFrame({ - "feature1": np.random.rand(20), - "feature2": np.random.choice(["A", "B", "C"], 20), - "label": np.random.rand(20) - }) - return DummyPredictionCalculation(df, "label") - - -class TestPredictionCalculation(CheckPredictionCalculation): - """Test class for `PredictionCalculationP` protocol compliance.""" - - # Note: instance_to_check pytest fixture is implemented above \ No newline at end of file From f9b8ef7fc25343053282e9c23a5859132b810b98 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Jun 2025 19:32:11 +0000 Subject: [PATCH 4/5] Fix linting issues found by pre-commit hooks --- tests/base_architecture/model_comparison/__init__.py | 2 +- .../model_comparison/model_comparison_check.py | 5 ++--- .../model_comparison/model_comparison_plot_check.py | 6 +++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/base_architecture/model_comparison/__init__.py b/tests/base_architecture/model_comparison/__init__.py index 1ab3659..451c1fa 100644 --- a/tests/base_architecture/model_comparison/__init__.py +++ b/tests/base_architecture/model_comparison/__init__.py @@ -1 +1 @@ -"""Tests for model comparison protocols.""" \ No newline at end of file +"""Tests for model comparison protocols.""" diff --git a/tests/base_architecture/model_comparison/model_comparison_check.py b/tests/base_architecture/model_comparison/model_comparison_check.py index 6810f4b..9e820f7 100644 --- a/tests/base_architecture/model_comparison/model_comparison_check.py +++ b/tests/base_architecture/model_comparison/model_comparison_check.py @@ -1,5 +1,4 @@ import pandas as pd -import pytest from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP @@ -60,7 +59,7 @@ def custom_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series: def test_error_calculations(self, instance_to_check: PredictionCalculationP) -> None: """Test that error calculations work correctly.""" # Create test data where we know the expected errors - n_samples = len(instance_to_check.df) + len(instance_to_check.df) label_values = instance_to_check.df[instance_to_check.label_col] # Create prediction that should have zero error @@ -80,4 +79,4 @@ def test_error_calculations(self, 
instance_to_check: PredictionCalculationP) -> # Check that offset prediction has consistent non-zero error offset_errors = result["pred_1_error"] - assert (offset_errors > 0).all() \ No newline at end of file + assert (offset_errors > 0).all() diff --git a/tests/base_architecture/model_comparison/model_comparison_plot_check.py b/tests/base_architecture/model_comparison/model_comparison_plot_check.py index 8506166..7987df6 100644 --- a/tests/base_architecture/model_comparison/model_comparison_plot_check.py +++ b/tests/base_architecture/model_comparison/model_comparison_plot_check.py @@ -17,13 +17,13 @@ def test_attributes(self, instance_to_check: ModelComparisonPlotP) -> None: # The comparison_calc should implement PredictionCalculationP protocol assert hasattr(instance_to_check.comparison_calc, "df") assert hasattr(instance_to_check.comparison_calc, "label_col") - assert hasattr(instance_to_check.comparison_calc, "__call__") + assert callable(instance_to_check.comparison_calc) def test_plot_error_distribution(self, instance_to_check: ModelComparisonPlotP) -> None: """Test the plot_error_distribution method.""" # Create test prediction data n_samples = len(instance_to_check.comparison_calc.df) - pred_data = [ + [ pd.Series(range(n_samples), index=instance_to_check.comparison_calc.df.index, name="pred_0"), pd.Series(range(n_samples, 2 * n_samples), index=instance_to_check.comparison_calc.df.index, name="pred_1"), ] @@ -46,4 +46,4 @@ def test_plot_model_performance_summary(self, instance_to_check: ModelComparison """Test the plot_model_performance_summary method.""" # Test method exists and is callable assert hasattr(instance_to_check, "plot_model_performance_summary") - assert callable(instance_to_check.plot_model_performance_summary) \ No newline at end of file + assert callable(instance_to_check.plot_model_performance_summary) From 09cd69c07061f3c3c2a6ab1db8b4a3fa6c5851b6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 9 Jun 2025 19:33:57 +0000 Subject: [PATCH 5/5] Fix mypy type annotations and formatting issues --- .../model_comparison/model_comparison_check.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/base_architecture/model_comparison/model_comparison_check.py b/tests/base_architecture/model_comparison/model_comparison_check.py index 9e820f7..c8b5fe8 100644 --- a/tests/base_architecture/model_comparison/model_comparison_check.py +++ b/tests/base_architecture/model_comparison/model_comparison_check.py @@ -63,10 +63,12 @@ def test_error_calculations(self, instance_to_check: PredictionCalculationP) -> label_values = instance_to_check.df[instance_to_check.label_col] # Create prediction that should have zero error - pred_exact = pd.Series(label_values.values, index=instance_to_check.df.index, name="pred_exact") + pred_exact: pd.Series = pd.Series(label_values.values, index=instance_to_check.df.index, name="pred_exact") # Create prediction with constant offset - pred_offset = pd.Series(label_values.values + 1, index=instance_to_check.df.index, name="pred_offset") + pred_offset: pd.Series = pd.Series( + label_values.to_numpy() + 1, index=instance_to_check.df.index, name="pred_offset" + ) pred_data = [pred_exact, pred_offset]
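
For reference, a minimal conforming implementation of `PredictionCalculationP` could look like the sketch below. It mirrors the `DummyPredictionCalculation` used in the (later removed) protocol tests; the class name `SimplePredictionCalculation`, the absolute-error default, and the toy data are illustrative assumptions rather than part of the patch series.

    from dataclasses import dataclass
    from typing import Callable, List, Optional

    import pandas as pd


    @dataclass
    class SimplePredictionCalculation:
        """Minimal example satisfying the PredictionCalculationP protocol."""

        df: pd.DataFrame
        label_col: str

        def __call__(
            self,
            pred_data: List[pd.Series],
            error_func: Optional[Callable[[pd.Series, pd.Series], pd.Series]] = None,
        ) -> pd.DataFrame:
            if error_func is None:
                # Assumed default for this sketch: absolute error, as for regression.
                def error_func(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
                    return (y_true - y_pred).abs()

            result = self.df.copy()
            y_true = self.df[self.label_col]
            # One `pred_{i}_error` column per prediction series, following the
            # naming convention documented in the protocol.
            for i, pred in enumerate(pred_data):
                result[f"pred_{i}_error"] = error_func(y_true, pred)
            return result


    # Example usage with two candidate prediction series sharing the DataFrame index:
    df = pd.DataFrame({"feature": [0.1, 0.5, 0.9], "label": [1.0, 2.0, 3.0]})
    calc = SimplePredictionCalculation(df, "label")
    errors = calc([pd.Series([1.1, 1.8, 3.2]), pd.Series([0.9, 2.1, 2.7])])
    print(errors[["label", "pred_0_error", "pred_1_error"]])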
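
Likewise, a Plotly-backed sketch of `ModelComparisonPlotP` built on top of such a calculator might look as follows. The class name, the histogram/scatter/bar chart choices, the mean-error summary metric, and the inches-to-pixels conversion for `figsize` are assumptions made for illustration, not behaviour specified by the protocol.

    from dataclasses import dataclass
    from typing import Any, List, Optional, Tuple

    import pandas as pd
    import plotly.graph_objects as go

    from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP


    @dataclass
    class PlotlyModelComparisonPlot:
        """Minimal example satisfying the ModelComparisonPlotP protocol."""

        comparison_calc: PredictionCalculationP

        def _names(self, pred_data: List[pd.Series], model_names: Optional[List[str]]) -> List[str]:
            # Default to "Model 0", "Model 1", ... when no names are given.
            return model_names or [f"Model {i}" for i in range(len(pred_data))]

        def plot_error_distribution(
            self,
            pred_data: List[pd.Series],
            model_names: Optional[List[str]] = None,
            figsize: Tuple[int, int] = (12, 8),
            bins: int = 30,
            **kwargs: Any,
        ) -> go.Figure:
            errors = self.comparison_calc(pred_data)
            fig = go.Figure()
            for i, name in enumerate(self._names(pred_data, model_names)):
                fig.add_trace(go.Histogram(x=errors[f"pred_{i}_error"], name=name, nbinsx=bins))
            # Interpreting figsize inches roughly as pixels is an assumption of this sketch.
            fig.update_layout(barmode="overlay", width=figsize[0] * 100, height=figsize[1] * 100)
            return fig

        def plot_pairwise_comparison(
            self,
            pred_data: List[pd.Series],
            model_names: Optional[List[str]] = None,
            figsize: Tuple[int, int] = (10, 8),
            **kwargs: Any,
        ) -> go.Figure:
            fig = go.Figure()
            # Plot each model's predictions against the true label values.
            y_true = self.comparison_calc.df[self.comparison_calc.label_col]
            for pred, name in zip(pred_data, self._names(pred_data, model_names)):
                fig.add_trace(go.Scatter(x=y_true, y=pred, mode="markers", name=name))
            fig.update_layout(width=figsize[0] * 100, height=figsize[1] * 100)
            return fig

        def plot_model_performance_summary(
            self,
            pred_data: List[pd.Series],
            model_names: Optional[List[str]] = None,
            metrics: Optional[List[str]] = None,
            figsize: Tuple[int, int] = (12, 6),
            **kwargs: Any,
        ) -> go.Figure:
            errors = self.comparison_calc(pred_data)
            names = self._names(pred_data, model_names)
            # Summarise each model by its mean pointwise error (assumed default metric).
            means = [errors[f"pred_{i}_error"].mean() for i in range(len(pred_data))]
            fig = go.Figure(go.Bar(x=names, y=means, name="mean error"))
            fig.update_layout(width=figsize[0] * 100, height=figsize[1] * 100)
            return fig

Returning `go.Figure` objects satisfies the `Figure` union alias defined in `model_comparison_plot_protocols`; the same interface could equally be met with Matplotlib figures.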