Skip to content

Commit 398b941

Browse files
authored
Add test params to regression presets. (#1493)
* Add test arguments to regression presets.
1 parent 298e75d commit 398b941

File tree

4 files changed

+114
-17
lines changed

4 files changed

+114
-17
lines changed

src/evidently/future/metric_types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import typing_inspect
2525

26+
from evidently._pydantic_compat import BaseModel
2627
from evidently.future._utils import not_implemented
2728
from evidently.future.datasets import Dataset
2829
from evidently.metric_results import Label
@@ -988,6 +989,11 @@ def run_test(
988989
)
989990

990991

992+
class MeanStdMetricTests(BaseModel):
    # Groups the test configurations for a metric that reports both a mean and
    # a standard deviation (used by the regression presets for MeanError, MAPE
    # and MAE), so callers can pass one object instead of two parallel lists.
    # NOTE(review): `SingleValueMetricTests` appears to be an Optional alias,
    # since `None` is a valid default here — confirm its declaration.
    mean: SingleValueMetricTests = None
    std: SingleValueMetricTests = None
991997
class MeanStdMetric(Metric["MeanStdCalculation"]):
992998
mean_tests: SingleValueMetricTests = None
993999
std_tests: SingleValueMetricTests = None

src/evidently/future/presets/regression.py

Lines changed: 72 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from typing import Dict
22
from typing import List
3+
from typing import Optional
34

45
from evidently.future.container import MetricContainer
6+
from evidently.future.metric_types import MeanStdMetricTests
57
from evidently.future.metric_types import Metric
68
from evidently.future.metric_types import MetricId
79
from evidently.future.metric_types import MetricResult
10+
from evidently.future.metric_types import SingleValueMetricTests
811
from evidently.future.metrics import MAE
912
from evidently.future.metrics import MAPE
1013
from evidently.future.metrics import RMSE
@@ -31,19 +34,31 @@ def __init__(
3134
pred_actual_plot: bool = False,
3235
error_plot: bool = False,
3336
error_distr: bool = False,
37+
mean_error_tests: Optional[MeanStdMetricTests] = None,
38+
mape_tests: Optional[MeanStdMetricTests] = None,
39+
rmse_tests: SingleValueMetricTests = None,
40+
mae_tests: Optional[MeanStdMetricTests] = None,
41+
r2score_tests: SingleValueMetricTests = None,
42+
abs_max_error_tests: SingleValueMetricTests = None,
3443
):
3544
self._pred_actual_plot = pred_actual_plot
3645
self._error_plot = error_plot
3746
self._error_distr = error_distr
47+
self._mean_error_tests = mean_error_tests or MeanStdMetricTests()
48+
self._mape_tests = mape_tests or MeanStdMetricTests()
49+
self._rmse_tests = rmse_tests
50+
self._mae_tests = mae_tests or MeanStdMetricTests()
51+
self._r2score_tests = r2score_tests
52+
self._abs_max_error_tests = abs_max_error_tests
3853

3954
def generate_metrics(self, context: Context) -> List[Metric]:
    """Build the metric list for the regression-quality container.

    Mean/std-style metrics (MeanError, MAPE, MAE) receive their test
    configuration split into per-component lists from a
    ``MeanStdMetricTests`` object; single-value metrics (RMSE, R2Score,
    AbsMaxError) receive a single ``tests`` list (may be ``None``).
    """
    return [
        MeanError(mean_tests=self._mean_error_tests.mean, std_tests=self._mean_error_tests.std),
        MAPE(mean_tests=self._mape_tests.mean, std_tests=self._mape_tests.std),
        RMSE(tests=self._rmse_tests),
        MAE(mean_tests=self._mae_tests.mean, std_tests=self._mae_tests.std),
        R2Score(tests=self._r2score_tests),
        AbsMaxError(tests=self._abs_max_error_tests),
    ]
4863

4964
def render(self, context: Context, results: Dict[MetricId, MetricResult]) -> List[BaseWidgetInfo]:
@@ -72,11 +87,21 @@ def render(self, context: Context, results: Dict[MetricId, MetricResult]) -> Lis
7287

7388

7489
class RegressionDummyQuality(MetricContainer):
90+
def __init__(
    self,
    mae_tests: SingleValueMetricTests = None,
    mape_tests: SingleValueMetricTests = None,
    rmse_tests: SingleValueMetricTests = None,
):
    """Store optional test lists for the dummy MAE/MAPE/RMSE metrics.

    Each argument is forwarded unchanged (``None`` means "no tests") to the
    corresponding Dummy* metric in ``generate_metrics``.
    """
    self._mae_tests = mae_tests
    self._mape_tests = mape_tests
    self._rmse_tests = rmse_tests
99+
75100
def generate_metrics(self, context: Context) -> List[Metric]:
    """Build the dummy baseline metrics, attaching any configured tests."""
    return [
        DummyMAE(tests=self._mae_tests),
        DummyMAPE(tests=self._mape_tests),
        DummyRMSE(tests=self._rmse_tests),
    ]
81106

82107
def render(self, context: Context, results: Dict[MetricId, MetricResult]) -> List[BaseWidgetInfo]:
@@ -91,21 +116,51 @@ def render(self, context: Context, results: Dict[MetricId, MetricResult]) -> Lis
91116

92117

93118
class RegressionPreset(MetricContainer):
    """Full regression preset: quality container plus standalone widgets.

    Wraps :class:`RegressionQuality` (with all plots enabled) and appends
    MAPE, AbsMaxError and R2Score widgets to the rendered output.  Test
    configurations are forwarded to the corresponding metrics; mean/std
    metrics take a ``MeanStdMetricTests`` object, single-value metrics take
    a plain test list.
    """

    _quality: Optional[RegressionQuality] = None

    def __init__(
        self,
        mean_error_tests: Optional[MeanStdMetricTests] = None,
        mape_tests: Optional[MeanStdMetricTests] = None,
        rmse_tests: SingleValueMetricTests = None,
        mae_tests: Optional[MeanStdMetricTests] = None,
        r2score_tests: SingleValueMetricTests = None,
        abs_max_error_tests: SingleValueMetricTests = None,
    ):
        self._quality = None
        # Normalize mean/std configs so attribute access (.mean/.std) is
        # always safe even when the caller passed nothing.
        self._mean_error_tests = mean_error_tests or MeanStdMetricTests()
        self._mape_tests = mape_tests or MeanStdMetricTests()
        self._rmse_tests = rmse_tests
        self._mae_tests = mae_tests or MeanStdMetricTests()
        self._r2score_tests = r2score_tests
        self._abs_max_error_tests = abs_max_error_tests

    def generate_metrics(self, context: Context) -> List[Metric]:
        """Create the quality container's metrics plus the preset extras.

        Side effect: stores the quality container on ``self._quality`` so
        ``render`` can delegate to it later.
        """
        # Keyword arguments instead of nine positionals: the original call
        # passed three bare booleans followed by six test configs, which is
        # fragile against signature changes and unreadable at the call site.
        self._quality = RegressionQuality(
            pred_actual_plot=True,
            error_plot=True,
            error_distr=True,
            mean_error_tests=self._mean_error_tests,
            mape_tests=self._mape_tests,
            rmse_tests=self._rmse_tests,
            mae_tests=self._mae_tests,
            r2score_tests=self._r2score_tests,
            abs_max_error_tests=self._abs_max_error_tests,
        )
        return self._quality.metrics(context) + [
            MAPE(mean_tests=self._mape_tests.mean, std_tests=self._mape_tests.std),
            AbsMaxError(tests=self._abs_max_error_tests),
            R2Score(tests=self._r2score_tests),
        ]

    def render(self, context: "Context", results: Dict[MetricId, MetricResult]) -> List[BaseWidgetInfo]:
        """Render the quality widgets followed by the extra metric widgets.

        Raises:
            ValueError: if ``generate_metrics`` has not run yet, i.e. there
                is no quality container to delegate rendering to.
        """
        if self._quality is None:
            raise ValueError("No _quality set in preset, something went wrong.")
        # Metric lookups must be constructed with the same test config as in
        # generate_metrics so they resolve to the same metric ids.
        return (
            self._quality.render(context, results)
            + context.get_metric_result(
                MAPE(mean_tests=self._mape_tests.mean, std_tests=self._mape_tests.std),
            ).widget
            + context.get_metric_result(AbsMaxError(tests=self._abs_max_error_tests)).widget
            + context.get_metric_result(R2Score(tests=self._r2score_tests)).widget
        )

tests/future/presets/__init__.py

Whitespace-only changes.

tests/future/presets/regression.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import pandas as pd
2+
import pytest
3+
4+
from evidently.future.datasets import DataDefinition
5+
from evidently.future.datasets import Dataset
6+
from evidently.future.datasets import Regression
7+
from evidently.future.metric_types import MeanStdMetricTests
8+
from evidently.future.presets import RegressionQuality
9+
from evidently.future.report import Report
10+
from evidently.future.tests import lt
11+
12+
13+
@pytest.mark.parametrize(
    "preset,expected_tests",
    [
        # No test config at all -> the snapshot should contain zero tests.
        (RegressionQuality(), 0),
        # Each mean/std slot of a MeanStdMetricTests contributes one test.
        (RegressionQuality(mean_error_tests=MeanStdMetricTests(mean=[lt(0.1)])), 1),
        (RegressionQuality(mean_error_tests=MeanStdMetricTests(std=[lt(0.1)])), 1),
        (RegressionQuality(mae_tests=MeanStdMetricTests(mean=[lt(0.1)])), 1),
        (RegressionQuality(mae_tests=MeanStdMetricTests(std=[lt(0.1)])), 1),
        (RegressionQuality(mape_tests=MeanStdMetricTests(mean=[lt(0.1)])), 1),
        (RegressionQuality(mape_tests=MeanStdMetricTests(std=[lt(0.1)])), 1),
        # Single-value metrics take a plain list of tests.
        (RegressionQuality(rmse_tests=[lt(0.1)]), 1),
        (RegressionQuality(r2score_tests=[lt(0.1)]), 1),
        (RegressionQuality(abs_max_error_tests=[lt(0.1)]), 1),
    ],
)
def test_regression_quality_preset_tests(preset, expected_tests):
    """Run the preset on a tiny regression dataset and count emitted tests."""
    report = Report([preset])
    # Five-point dataset with prediction = target - 1, enough to exercise
    # every regression metric without any real fixture.
    dataset = Dataset.from_pandas(
        pd.DataFrame(data=dict(target=[1, 2, 3, 4, 5], prediction=[0, 1, 2, 3, 4])),
        data_definition=DataDefinition(regression=[Regression()]),
    )
    snapshot = report.run(dataset)
    snapshot_data = snapshot.dict()
    assert len(snapshot_data["tests"]) == expected_tests

0 commit comments

Comments
 (0)