22
33namespace AiDotNet ;
44
5- public class Metrics : IMetrics
5+ public sealed class Metrics : IMetrics
66{
77 public double MeanSquaredError { get ; private set ; }
88 public double RootMeanSquaredError { get ; private set ; }
9-
9+ public double AdjustedR2 { get ; private set ; }
10+ public double R2 { get ; private set ; }
11+ public double PredictionsStandardError { get ; private set ; }
12+ public double PredictionsStandardDeviation { get ; private set ; }
13+ public double AverageStandardError { get ; private set ; }
14+ public double AverageStandardDeviation { get ; private set ; }
15+ public int DegreesOfFreedom { get ; private set ; }
16+
1017 private double [ ] _oosPredictions { get ; }
18+ private double _oosPredictionsAvg { get ; }
1119 private double [ ] _oosActualValues { get ; }
20+ private double _oosActualValuesAvg { get ; }
21+ private int _paramsCount { get ; }
22+ private int _sampleSize { get ; }
23+ private double _residualSumOfSquares { get ; set ; }
24+ private double _totalSumOfSquares { get ; set ; }
1225
13- public Metrics ( double [ ] OosPredictions , double [ ] OosActualValues )
26+ public Metrics ( double [ ] OosPredictions , double [ ] OosActualValues , int paramCount )
1427 {
1528 _oosPredictions = OosPredictions ;
29+ _oosPredictionsAvg = _oosPredictions . Average ( ) ;
1630 _oosActualValues = OosActualValues ;
31+ _oosActualValuesAvg = _oosActualValues . Average ( ) ;
32+ _paramsCount = paramCount ;
33+ _sampleSize = _oosPredictions . Length ;
1734
35+ DegreesOfFreedom = CalculateDegreesOfFreedom ( ) ;
36+ R2 = CalculateR2 ( ) ;
37+ AdjustedR2 = CalculateAdjustedR2 ( R2 ) ;
38+ AverageStandardDeviation = CalculateAverageStandardDeviation ( ) ;
39+ PredictionsStandardDeviation = CalculatePredictionStandardDeviation ( ) ;
40+ AverageStandardError = CalculateAverageStandardError ( ) ;
41+ PredictionsStandardError = CalculatePredictionStandardError ( ) ;
1842 MeanSquaredError = CalculateMeanSquaredError ( ) ;
1943 RootMeanSquaredError = CalculateRootMeanSquaredError ( ) ;
2044 }
2145
22- internal sealed override double CalculateMeanSquaredError ( )
46+ internal override double CalculateMeanSquaredError ( )
47+ {
48+ return _residualSumOfSquares / _sampleSize ;
49+ }
50+
51+ internal override double CalculateRootMeanSquaredError ( )
52+ {
53+ return MeanSquaredError >= 0 ? Math . Sqrt ( MeanSquaredError ) : 0 ;
54+ }
55+
56+ internal override double CalculateR2 ( )
2357 {
24- double sum = 0 ;
25- for ( var i = 0 ; i < _oosPredictions . Length ; i ++ )
58+ double residualSumSquares = 0 , totalSumSquares = 0 ;
59+ for ( int i = 0 ; i < _sampleSize ; i ++ )
2660 {
27- sum += Math . Pow ( _oosActualValues [ i ] - _oosPredictions [ i ] , 2 ) ;
61+ residualSumSquares += Math . Pow ( _oosActualValues [ i ] - _oosPredictions [ i ] , 2 ) ;
62+ totalSumSquares += Math . Pow ( _oosActualValues [ i ] - _oosActualValuesAvg , 2 ) ;
2863 }
2964
30- return sum / _oosPredictions . Length ;
65+ // We are saving these values for later reuse
66+ _residualSumOfSquares = residualSumSquares ;
67+ _totalSumOfSquares = totalSumSquares ;
68+
69+ return _totalSumOfSquares != 0 ? 1 - ( _residualSumOfSquares / _totalSumOfSquares ) : 0 ;
3170 }
3271
33- internal sealed override double CalculateRootMeanSquaredError ( )
72+ internal override double CalculateAdjustedR2 ( double r2 )
3473 {
35- return MeanSquaredError >= 0 ? Math . Sqrt ( MeanSquaredError ) : 0 ;
74+ return _sampleSize != 1 && DegreesOfFreedom != 1 ? 1 - ( 1 - Math . Pow ( r2 , 2 ) ) * ( _sampleSize - 1 ) / ( DegreesOfFreedom - 1 ) : 0 ;
75+ }
76+
77+ internal override double CalculateAverageStandardError ( )
78+ {
79+ return AverageStandardDeviation / Math . Sqrt ( _sampleSize ) ;
80+ }
81+
82+ internal override double CalculatePredictionStandardError ( )
83+ {
84+ return PredictionsStandardDeviation / Math . Sqrt ( _sampleSize ) ;
85+ }
86+
87+ private static double CalculateStandardDeviation ( double avgSumSquares )
88+ {
89+ return avgSumSquares >= 0 ? Math . Sqrt ( avgSumSquares ) : 0 ;
90+ }
91+
92+ internal override double CalculateAverageStandardDeviation ( )
93+ {
94+ var avgSumSquares = _totalSumOfSquares / _sampleSize ;
95+
96+ return CalculateStandardDeviation ( avgSumSquares ) ;
97+ }
98+
99+ internal override double CalculatePredictionStandardDeviation ( )
100+ {
101+ var avgSumSquares = _residualSumOfSquares / _sampleSize ;
102+
103+ return CalculateStandardDeviation ( avgSumSquares ) ;
104+ }
105+
106+ internal override int CalculateDegreesOfFreedom ( )
107+ {
108+ return _sampleSize - _paramsCount ;
36109 }
37110}
0 commit comments