11from typing import Dict
22from typing import List
3+ from typing import Optional
34
45from evidently .future .container import MetricContainer
6+ from evidently .future .metric_types import MeanStdMetricTests
57from evidently .future .metric_types import Metric
68from evidently .future .metric_types import MetricId
79from evidently .future .metric_types import MetricResult
10+ from evidently .future .metric_types import SingleValueMetricTests
811from evidently .future .metrics import MAE
912from evidently .future .metrics import MAPE
1013from evidently .future .metrics import RMSE
@@ -31,19 +34,31 @@ def __init__(
3134 pred_actual_plot : bool = False ,
3235 error_plot : bool = False ,
3336 error_distr : bool = False ,
37+ mean_error_tests : Optional [MeanStdMetricTests ] = None ,
38+ mape_tests : Optional [MeanStdMetricTests ] = None ,
39+ rmse_tests : SingleValueMetricTests = None ,
40+ mae_tests : Optional [MeanStdMetricTests ] = None ,
41+ r2score_tests : SingleValueMetricTests = None ,
42+ abs_max_error_tests : SingleValueMetricTests = None ,
3443 ):
3544 self ._pred_actual_plot = pred_actual_plot
3645 self ._error_plot = error_plot
3746 self ._error_distr = error_distr
47+ self ._mean_error_tests = mean_error_tests or MeanStdMetricTests ()
48+ self ._mape_tests = mape_tests or MeanStdMetricTests ()
49+ self ._rmse_tests = rmse_tests
50+ self ._mae_tests = mae_tests or MeanStdMetricTests ()
51+ self ._r2score_tests = r2score_tests
52+ self ._abs_max_error_tests = abs_max_error_tests
3853
3954 def generate_metrics (self , context : Context ) -> List [Metric ]:
4055 return [
41- MeanError (),
42- MAPE (),
43- RMSE (),
44- MAE (),
45- R2Score (),
46- AbsMaxError (),
56+ MeanError (mean_tests = self . _mean_error_tests . mean , std_tests = self . _mean_error_tests . std ),
57+ MAPE (mean_tests = self . _mape_tests . mean , std_tests = self . _mape_tests . std ),
58+ RMSE (tests = self . _rmse_tests ),
59+ MAE (mean_tests = self . _mae_tests . mean , std_tests = self . _mae_tests . std ),
60+ R2Score (tests = self . _r2score_tests ),
61+ AbsMaxError (tests = self . _abs_max_error_tests ),
4762 ]
4863
4964 def render (self , context : Context , results : Dict [MetricId , MetricResult ]) -> List [BaseWidgetInfo ]:
@@ -72,11 +87,21 @@ def render(self, context: Context, results: Dict[MetricId, MetricResult]) -> Lis
7287
7388
7489class RegressionDummyQuality (MetricContainer ):
90+ def __init__ (
91+ self ,
92+ mae_tests : SingleValueMetricTests = None ,
93+ mape_tests : SingleValueMetricTests = None ,
94+ rmse_tests : SingleValueMetricTests = None ,
95+ ):
96+ self ._mae_tests = mae_tests
97+ self ._mape_tests = mape_tests
98+ self ._rmse_tests = rmse_tests
99+
75100 def generate_metrics (self , context : Context ) -> List [Metric ]:
76101 return [
77- DummyMAE (),
78- DummyMAPE (),
79- DummyRMSE (),
102+ DummyMAE (tests = self . _mae_tests ),
103+ DummyMAPE (tests = self . _mape_tests ),
104+ DummyRMSE (tests = self . _rmse_tests ),
80105 ]
81106
82107 def render (self , context : Context , results : Dict [MetricId , MetricResult ]) -> List [BaseWidgetInfo ]:
@@ -91,21 +116,51 @@ def render(self, context: Context, results: Dict[MetricId, MetricResult]) -> Lis
91116
92117
93118class RegressionPreset (MetricContainer ):
94- def __init__ (self ):
119+ _quality : Optional [RegressionQuality ] = None
120+
121+ def __init__ (
122+ self ,
123+ mean_error_tests : Optional [MeanStdMetricTests ] = None ,
124+ mape_tests : Optional [MeanStdMetricTests ] = None ,
125+ rmse_tests : SingleValueMetricTests = None ,
126+ mae_tests : Optional [MeanStdMetricTests ] = None ,
127+ r2score_tests : SingleValueMetricTests = None ,
128+ abs_max_error_tests : SingleValueMetricTests = None ,
129+ ):
95130 self ._quality = None
131+ self ._mean_error_tests = mean_error_tests or MeanStdMetricTests ()
132+ self ._mape_tests = mape_tests or MeanStdMetricTests ()
133+ self ._rmse_tests = rmse_tests
134+ self ._mae_tests = mae_tests or MeanStdMetricTests ()
135+ self ._r2score_tests = r2score_tests
136+ self ._abs_max_error_tests = abs_max_error_tests
96137
97138 def generate_metrics (self , context : Context ) -> List [Metric ]:
98- self ._quality = RegressionQuality (True , True , True )
139+ self ._quality = RegressionQuality (
140+ True ,
141+ True ,
142+ True ,
143+ self ._mean_error_tests ,
144+ self ._mape_tests ,
145+ self ._rmse_tests ,
146+ self ._mae_tests ,
147+ self ._r2score_tests ,
148+ self ._abs_max_error_tests ,
149+ )
99150 return self ._quality .metrics (context ) + [
100- MAPE (),
101- AbsMaxError (),
102- R2Score (),
151+ MAPE (mean_tests = self . _mape_tests . mean , std_tests = self . _mape_tests . std ),
152+ AbsMaxError (tests = self . _abs_max_error_tests ),
153+ R2Score (tests = self . _r2score_tests ),
103154 ]
104155
105156 def render (self , context : "Context" , results : Dict [MetricId , MetricResult ]) -> List [BaseWidgetInfo ]:
157+ if self ._quality is None :
158+ raise ValueError ("No _quality set in preset, something went wrong." )
106159 return (
107160 self ._quality .render (context , results )
108- + context .get_metric_result (MAPE ()).widget
109- + context .get_metric_result (AbsMaxError ()).widget
110- + context .get_metric_result (R2Score ()).widget
161+ + context .get_metric_result (
162+ MAPE (mean_tests = self ._mape_tests .mean , std_tests = self ._mape_tests .std ),
163+ ).widget
164+ + context .get_metric_result (AbsMaxError (tests = self ._abs_max_error_tests )).widget
165+ + context .get_metric_result (R2Score (tests = self ._r2score_tests )).widget
111166 )
0 commit comments