Design protocol classes and update architecture for model comparison analysis #110
Draft: Copilot wants to merge 5 commits into main from copilot/fix-109.
Changes shown are from 2 of the 5 commits.

Commits (all authored by Copilot):
- 85cdb51 Initial plan for issue
- 1e9ea49 Implement model comparison protocol classes and tests
- 69bc74a Remove test files and fix protocol method style consistency
- f9b8ef7 Fix linting issues found by pre-commit hooks
- 09cd69c Fix mypy type annotations and formatting issues
tab_right/base_architecture/model_comparison_plot_protocols.py (126 additions, 0 deletions)
"""Protocols for model comparison visualization in tab-right.

This module defines protocols for visualizing model comparison results. It provides
interfaces for creating visualizations that compare multiple prediction datasets
against true labels.
"""

from dataclasses import dataclass
from typing import Any, List, Optional, Protocol, Tuple, Union, runtime_checkable

import matplotlib.pyplot as plt
import pandas as pd
import plotly.graph_objects as go

from .model_comparison_protocols import PredictionCalculationP

Figure = Union[go.Figure, plt.Figure]


@runtime_checkable
@dataclass
class ModelComparisonPlotP(Protocol):
    """Protocol for model comparison visualization implementations.

    This protocol defines the interface for visualizing comparison results
    between multiple predictions and true labels.

    Parameters
    ----------
    comparison_calc : PredictionCalculationP
        An implementation of PredictionCalculationP that provides the comparison
        metrics to visualize.

    """

    comparison_calc: PredictionCalculationP

    def plot_error_distribution(
        self,
        pred_data: List[pd.Series],
        model_names: Optional[List[str]] = None,
        figsize: Tuple[int, int] = (12, 8),
        bins: int = 30,
        **kwargs: Any,
    ) -> Figure:
        """Create a visualization comparing error distributions across models.

        Parameters
        ----------
        pred_data : List[pd.Series]
            List of prediction Series to compare against the label.
        model_names : Optional[List[str]], default None
            Names for each model. If None, uses default names like "Model 0", "Model 1", etc.
        figsize : Tuple[int, int], default (12, 8)
            Figure size as (width, height) in inches.
        bins : int, default 30
            Number of bins for histogram visualization.
        **kwargs : Any
            Additional parameters for the plotting implementation.

        Returns
        -------
        Figure
            A figure object containing the error distribution comparison.

        """
        ...

    def plot_pairwise_comparison(
        self,
        pred_data: List[pd.Series],
        model_names: Optional[List[str]] = None,
        figsize: Tuple[int, int] = (10, 8),
        **kwargs: Any,
    ) -> Figure:
        """Create a pairwise comparison plot between models.

        Parameters
        ----------
        pred_data : List[pd.Series]
            List of prediction Series to compare against the label.
        model_names : Optional[List[str]], default None
            Names for each model. If None, uses default names.
        figsize : Tuple[int, int], default (10, 8)
            Figure size as (width, height) in inches.
        **kwargs : Any
            Additional parameters for the plotting implementation.

        Returns
        -------
        Figure
            A figure object containing the pairwise comparison visualization.

        """
        ...


...
An outdated collaborator review comment flagged the stray `...` at the end of this file and suggested a change removing it.
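For orientation, here is a minimal sketch of what a concrete implementation of this protocol's plot_error_distribution could look like with the matplotlib backend. The class name MatplotlibModelComparisonPlot and the overlaid-histogram design are illustrative assumptions, not part of this PR; only the method signature and the pred_{i}_error column convention come from the diffs on this page.

from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

import matplotlib.pyplot as plt
import pandas as pd

from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP


@dataclass
class MatplotlibModelComparisonPlot:  # hypothetical class, not in this PR
    comparison_calc: PredictionCalculationP

    def plot_error_distribution(
        self,
        pred_data: List[pd.Series],
        model_names: Optional[List[str]] = None,
        figsize: Tuple[int, int] = (12, 8),
        bins: int = 30,
        **kwargs: Any,
    ) -> plt.Figure:
        # Fall back to "Model 0", "Model 1", ... as the protocol docstring describes.
        names = model_names or [f"Model {i}" for i in range(len(pred_data))]
        # Delegate the pointwise error computation to the injected calculator.
        errors = self.comparison_calc(pred_data)
        fig, ax = plt.subplots(figsize=figsize)
        for i, name in enumerate(names):
            # One semi-transparent histogram of pointwise errors per model.
            ax.hist(errors[f"pred_{i}_error"], bins=bins, alpha=0.5, label=name)
        ax.set_xlabel("Pointwise error")
        ax.set_ylabel("Count")
        ax.legend()
        return fig

A full implementation would also need plot_pairwise_comparison before it structurally satisfies ModelComparisonPlotP.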
tab_right/base_architecture/model_comparison_protocols.py (61 additions, 0 deletions)
| """Protocol definitions for model comparison analysis in tab-right. | ||
|
|
||
| This module defines protocol classes for model comparison analysis, | ||
| including interfaces for prediction calculation and comparison between multiple models. | ||
| """ | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import Callable, List, Optional, Protocol, runtime_checkable | ||
|
|
||
| import pandas as pd | ||
|
|
||
|
|
||
| @runtime_checkable | ||
| @dataclass | ||
| class PredictionCalculationP(Protocol): | ||
| """Protocol for prediction calculation implementations. | ||
|
|
||
| This protocol defines the interface for calculating pointwise errors | ||
| between multiple sets of predictions and true labels. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| df : pd.DataFrame | ||
| A DataFrame containing the data for analysis. | ||
| label_col : str | ||
| Column name for the true target values. | ||
|
|
||
| """ | ||
|
|
||
| df: pd.DataFrame | ||
| label_col: str | ||
|
|
||
| def __call__( | ||
| self, | ||
| pred_data: List[pd.Series], | ||
| error_func: Optional[Callable[[pd.Series, pd.Series], pd.Series]] = None, | ||
| ) -> pd.DataFrame: | ||
| """Calculate pointwise errors for multiple prediction series against the label. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| pred_data : List[pd.Series] | ||
| List of prediction Series to compare against the label. | ||
| Each Series should have the same index as the DataFrame. | ||
| error_func : Optional[Callable[[pd.Series, pd.Series], pd.Series]], default None | ||
| Function for calculating pairwise error between predictions and labels. | ||
| If None, defaults to a standard metric (e.g., absolute error for regression, | ||
| 0/1 loss for classification). | ||
| Function signature: error_func(y_true, y_pred) -> pd.Series | ||
|
|
||
| Returns | ||
| ------- | ||
| pd.DataFrame | ||
| DataFrame containing pointwise errors for each prediction series. | ||
| Expected columns: | ||
| - Original DataFrame columns (as context) | ||
| - `{label_col}`: The true label values | ||
| - `pred_0_error`, `pred_1_error`, ...: Pointwise errors for each prediction series | ||
| - `model_id`: Optional identifier for the prediction series | ||
|
|
||
| """ |
New one-line file (path not shown in the capture):

"""Tests for model comparison protocols."""
tests/base_architecture/model_comparison/model_comparison_check.py (83 additions, 0 deletions)
import pandas as pd
import pytest

from tab_right.base_architecture.model_comparison_protocols import PredictionCalculationP

from ..base_protocols_check import CheckProtocols


class CheckPredictionCalculation(CheckProtocols):
    """Class for checking compliance of `PredictionCalculationP` protocol."""

    # Use the protocol type directly
    class_to_check = PredictionCalculationP

    def test_attributes(self, instance_to_check: PredictionCalculationP) -> None:
        """Test attributes of the instance to ensure compliance."""
        assert hasattr(instance_to_check, "df")
        assert hasattr(instance_to_check, "label_col")
        assert isinstance(instance_to_check.df, pd.DataFrame)
        assert isinstance(instance_to_check.label_col, str)
        assert instance_to_check.label_col in instance_to_check.df.columns

    def test_call_method(self, instance_to_check: PredictionCalculationP) -> None:
        """Test the __call__ method of the instance."""
        # Create test prediction data
        n_samples = len(instance_to_check.df)
        pred_data = [
            pd.Series(range(n_samples), index=instance_to_check.df.index, name="pred_0"),
            pd.Series(range(n_samples, 2 * n_samples), index=instance_to_check.df.index, name="pred_1"),
        ]

        # Test with default error function
        result = instance_to_check(pred_data)
        assert isinstance(result, pd.DataFrame)

        # Check that the result contains original DataFrame columns
        for col in instance_to_check.df.columns:
            assert col in result.columns

        # Check that error columns are present
        assert "pred_0_error" in result.columns
        assert "pred_1_error" in result.columns

        # Test with custom error function
        def custom_error(y_true: pd.Series, y_pred: pd.Series) -> pd.Series:
            return abs(y_true - y_pred)

        result_custom = instance_to_check(pred_data, error_func=custom_error)
        assert isinstance(result_custom, pd.DataFrame)
        assert "pred_0_error" in result_custom.columns
        assert "pred_1_error" in result_custom.columns

        # Test with single prediction
        single_pred = [pred_data[0]]
        result_single = instance_to_check(single_pred)
        assert isinstance(result_single, pd.DataFrame)
        assert "pred_0_error" in result_single.columns
        assert "pred_1_error" not in result_single.columns

    def test_error_calculations(self, instance_to_check: PredictionCalculationP) -> None:
        """Test that error calculations work correctly."""
        # Create test data where we know the expected errors
        n_samples = len(instance_to_check.df)
        label_values = instance_to_check.df[instance_to_check.label_col]

        # Create prediction that should have zero error
        pred_exact = pd.Series(label_values.values, index=instance_to_check.df.index, name="pred_exact")

        # Create prediction with constant offset
        pred_offset = pd.Series(label_values.values + 1, index=instance_to_check.df.index, name="pred_offset")

        pred_data = [pred_exact, pred_offset]

        # Test with default error function
        result = instance_to_check(pred_data)

        # Check that exact prediction has zero error (or very close to zero)
        exact_errors = result["pred_0_error"]
        assert (exact_errors < 1e-10).all() or (exact_errors == 0).all()

        # Check that offset prediction has consistent non-zero error
        offset_errors = result["pred_1_error"]
        assert (offset_errors > 0).all()
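The instance_to_check parameter suggests that CheckProtocols supplies concrete instances through a pytest fixture. Under that assumption, wiring these generic checks to the hypothetical AbsoluteErrorCalculation sketched earlier might look like this:

import pandas as pd
import pytest


class TestAbsoluteErrorCalculation(CheckPredictionCalculation):
    """Hypothetical binding of the generic protocol checks to a concrete class."""

    @pytest.fixture
    def instance_to_check(self) -> PredictionCalculationP:
        # Small numeric frame; "target" serves as the label column.
        df = pd.DataFrame({"feature": [0.5, 1.5, 2.5], "target": [1.0, 2.0, 3.0]})
        return AbsoluteErrorCalculation(df=df, label_col="target")

With absolute error as the default, such a fixture would pass all three checks: the exact prediction in test_error_calculations yields zero error and the constant offset yields a strictly positive one.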
tests/base_architecture/model_comparison/model_comparison_plot_check.py (49 additions, 0 deletions)
import pandas as pd

from tab_right.base_architecture.model_comparison_plot_protocols import ModelComparisonPlotP

from ..base_protocols_check import CheckProtocols


class CheckModelComparisonPlot(CheckProtocols):
    """Class for checking compliance of `ModelComparisonPlotP` protocol."""

    # Use the protocol type directly
    class_to_check = ModelComparisonPlotP

    def test_attributes(self, instance_to_check: ModelComparisonPlotP) -> None:
        """Test attributes of the instance to ensure compliance."""
        assert hasattr(instance_to_check, "comparison_calc")
        # The comparison_calc should implement PredictionCalculationP protocol
        assert hasattr(instance_to_check.comparison_calc, "df")
        assert hasattr(instance_to_check.comparison_calc, "label_col")
        assert hasattr(instance_to_check.comparison_calc, "__call__")

    def test_plot_error_distribution(self, instance_to_check: ModelComparisonPlotP) -> None:
        """Test the plot_error_distribution method."""
        # Create test prediction data
        n_samples = len(instance_to_check.comparison_calc.df)
        pred_data = [
            pd.Series(range(n_samples), index=instance_to_check.comparison_calc.df.index, name="pred_0"),
            pd.Series(range(n_samples, 2 * n_samples), index=instance_to_check.comparison_calc.df.index, name="pred_1"),
        ]

        # Test method exists and is callable
        assert hasattr(instance_to_check, "plot_error_distribution")
        assert callable(instance_to_check.plot_error_distribution)

        # Test that method can be called with default parameters
        # Note: We're not testing the actual plotting since these are protocols
        # The implementation would be tested separately

    def test_plot_pairwise_comparison(self, instance_to_check: ModelComparisonPlotP) -> None:
        """Test the plot_pairwise_comparison method."""
        # Test method exists and is callable
        assert hasattr(instance_to_check, "plot_pairwise_comparison")
        assert callable(instance_to_check.plot_pairwise_comparison)

    def test_plot_model_performance_summary(self, instance_to_check: ModelComparisonPlotP) -> None:
        """Test the plot_model_performance_summary method."""
        # Test method exists and is callable
        assert hasattr(instance_to_check, "plot_model_performance_summary")
        assert callable(instance_to_check.plot_model_performance_summary)
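The plot checks would presumably be bound the same way. A hypothetical wiring that reuses both sketches from above:

import pandas as pd
import pytest


class TestMatplotlibModelComparisonPlot(CheckModelComparisonPlot):
    """Hypothetical binding of the plot protocol checks to a concrete class."""

    @pytest.fixture
    def instance_to_check(self) -> ModelComparisonPlotP:
        # Reuse the hypothetical calculation class as the injected dependency.
        df = pd.DataFrame({"feature": [0.5, 1.5, 2.5], "target": [1.0, 2.0, 3.0]})
        calc = AbsoluteErrorCalculation(df=df, label_col="target")
        return MatplotlibModelComparisonPlot(comparison_calc=calc)

Note that test_plot_model_performance_summary asserts a method that does not appear in the ModelComparisonPlotP diff shown above, so a concrete class would need plot_pairwise_comparison and plot_model_performance_summary in addition to the sketched plot_error_distribution to pass all four tests.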