@@ -40,38 +40,77 @@ static bool check_if_approximately_zero(TReal a, TReal tolerance = std::numeric_
4040 return false ;
4141}
4242
43+ VectorXd calculate_gaussian_errors (const VectorXd &y,const VectorXd &predicted)
44+ {
45+ VectorXd errors{y-predicted};
46+ errors=errors.array ()*errors.array ();
47+ return errors;
48+ }
49+
50+ VectorXd calculate_logit_errors (const VectorXd &y,const VectorXd &predicted)
51+ {
52+ VectorXd errors{-y.array () * predicted.array ().log () - (1.0 -y.array ()).array () * (1.0 -predicted.array ()).log ()};
53+ return errors;
54+ }
55+
56+ VectorXd calculate_poisson_errors (const VectorXd &y,const VectorXd &predicted)
57+ {
58+ VectorXd errors{predicted.array () - y.array ()*predicted.array ().log ()};
59+ return errors;
60+ }
61+
62+ VectorXd calculate_gamma_errors (const VectorXd &y,const VectorXd &predicted)
63+ {
64+ VectorXd errors{predicted.array ().log () - y.array ().log () + y.array ()/predicted.array ()-1 };
65+ return errors;
66+ }
67+
68+ VectorXd calculate_poissongamma_errors (const VectorXd &y,const VectorXd &predicted)
69+ {
70+ VectorXd errors{y.array ().pow (0.5 ).array () / (-0.25 ) + y.array ()*predicted.array ().pow (-0.5 ) / 0.5 + predicted.array ().pow (0.5 ) / 0.5 };
71+ return errors;
72+ }
73+
4374// Computes errors (for each observation) based on error metric for a vector
44- VectorXd calculate_errors (const VectorXd &y,const VectorXd &predicted,const VectorXd &sample_weight=VectorXd(0 ),bool loss_function_mse=true )
75+ VectorXd calculate_errors (const VectorXd &y,const VectorXd &predicted,const VectorXd &sample_weight=VectorXd(0 ),const std::string &family="gaussian" )
4576{
4677 // Error per observation before adjustment for sample weights
47- VectorXd residuals{y-predicted};
48- if (loss_function_mse)
49- residuals=residuals.array ()*residuals.array ();
50- else
51- residuals=residuals.cwiseAbs ();
52-
78+ VectorXd errors;
79+ if (family==" gaussian" )
80+ errors=calculate_gaussian_errors (y,predicted);
81+ else if (family==" logit" )
82+ errors=calculate_logit_errors (y,predicted);
83+ else if (family==" poisson" )
84+ errors=calculate_poisson_errors (y,predicted);
85+ else if (family==" gamma" )
86+ errors=calculate_gamma_errors (y,predicted);
87+ else if (family==" poissongamma" )
88+ errors=calculate_poissongamma_errors (y,predicted);
5389 // Adjusting for sample weights if specified
5490 if (sample_weight.size ()>0 )
55- residuals=residuals .array ()*sample_weight.array ();
91+ errors=errors .array ()*sample_weight.array ();
5692
57- return residuals;
93+ return errors;
94+ }
95+
// Gaussian (squared-error) loss for a single observation: (y - predicted)^2.
double calculate_gaussian_error_one_observation(double y,double predicted)
{
    const double residual{y-predicted};
    return residual*residual;
}
59102
60103// Computes error for one observation based on error metric
61- double calculate_error_one_observation (double y,double predicted,double sample_weight=NAN_DOUBLE, bool loss_function_mse= true )
104+ double calculate_error_one_observation (double y,double predicted,double sample_weight=NAN_DOUBLE)
62105{
63106 // Error per observation before adjustment for sample weights
64- double residual{y-predicted};
65- if (loss_function_mse)
66- residual=residual*residual;
67- else
68- residual=abs (residual);
69-
107+ double error{calculate_gaussian_error_one_observation (y,predicted)};
108+
70109 // Adjusting for sample weights if specified
71110 if (!std::isnan (sample_weight))
72- residual=residual *sample_weight;
73-
74- return residual ;
111+ error=error *sample_weight;
112+
113+ return error ;
75114}
76115
77116// Computes overall error based on errors from calculate_errors(), returning one value
@@ -88,6 +127,66 @@ double calculate_error(const VectorXd &errors,const VectorXd &sample_weight=Vect
88127 return error;
89128}
90129
130+ VectorXd transform_zero_to_negative (const VectorXd &linear_predictor)
131+ {
132+ VectorXd transformed_linear_predictor{linear_predictor};
133+ for (size_t i = 0 ; i < static_cast <size_t >(transformed_linear_predictor.rows ()); ++i)
134+ {
135+ bool row_is_not_negative{std::isgreaterequal (transformed_linear_predictor[i],0.0 )};
136+ if (row_is_not_negative)
137+ transformed_linear_predictor[i]=SMALL_NEGATIVE_VALUE;
138+ }
139+ return transformed_linear_predictor;
140+ }
141+
142+ VectorXd transform_linear_predictor_to_predictions (const VectorXd &linear_predictor, const std::string &family=" gaussian" )
143+ {
144+ if (family==" gaussian" )
145+ return linear_predictor;
146+ else if (family==" logit" )
147+ {
148+ VectorXd exp_of_linear_predictor{linear_predictor.array ().exp ()};
149+ return exp_of_linear_predictor.array () / (1.0 + exp_of_linear_predictor.array ());
150+ }
151+ else if (family==" poisson" )
152+ return linear_predictor.array ().exp ();
153+ else if (family==" gamma" )
154+ {
155+ VectorXd transformed_linear_predictor{transform_zero_to_negative (linear_predictor)};
156+ return -1 /transformed_linear_predictor.array ();
157+ }
158+ else if (family==" poissongamma" )
159+ {
160+ VectorXd transformed_linear_predictor{transform_zero_to_negative (linear_predictor)};
161+ return transformed_linear_predictor.array ().pow (-2 ).array () * 4.0 ;
162+ }
163+ return VectorXd (0 );
164+ }
165+
166+ double transform_linear_predictor_to_prediction (double linear_predictor, const std::string &family=" gaussian" )
167+ {
168+ if (family==" gaussian" )
169+ return linear_predictor;
170+ else if (family==" logit" )
171+ {
172+ double exp_of_linear_predictor{std::exp (linear_predictor)};
173+ return exp_of_linear_predictor / (1.0 + exp_of_linear_predictor);
174+ }
175+ else if (family==" poisson" )
176+ return std::exp (linear_predictor);
177+ else if (family==" gamma" )
178+ {
179+ double negative_linear_predictor{std::min (linear_predictor,SMALL_NEGATIVE_VALUE)};
180+ return -1 /negative_linear_predictor;
181+ }
182+ else if (family==" poissongamma" )
183+ {
184+ double negative_linear_predictor{std::min (linear_predictor,SMALL_NEGATIVE_VALUE)};
185+ return 4.0 * std::pow (negative_linear_predictor,-2 );
186+ }
187+ return NAN_DOUBLE;
188+ }
189+
91190// sorts index based on v
92191VectorXi sort_indexes_ascending (const VectorXd &v)
93192{
0 commit comments