@@ -40,38 +40,77 @@ static bool check_if_approximately_zero(TReal a, TReal tolerance = std::numeric_
4040    return  false ;
4141}
4242
43+ VectorXd calculate_gaussian_errors (const  VectorXd &y,const  VectorXd &predicted)
44+ {
45+     VectorXd errors{y-predicted};
46+     errors=errors.array ()*errors.array ();
47+     return  errors;
48+ }
49+ 
50+ VectorXd calculate_logit_errors (const  VectorXd &y,const  VectorXd &predicted)
51+ {
52+     VectorXd errors{-y.array () * predicted.array ().log ()  -  (1.0 -y.array ()).array () * (1.0 -predicted.array ()).log ()};
53+     return  errors;
54+ }
55+ 
56+ VectorXd calculate_poisson_errors (const  VectorXd &y,const  VectorXd &predicted)
57+ {
58+     VectorXd errors{predicted.array () - y.array ()*predicted.array ().log ()};
59+     return  errors;
60+ }
61+ 
62+ VectorXd calculate_gamma_errors (const  VectorXd &y,const  VectorXd &predicted)
63+ {
64+     VectorXd errors{predicted.array ().log () - y.array ().log () + y.array ()/predicted.array ()-1 };
65+     return  errors;
66+ }
67+ 
68+ VectorXd calculate_poissongamma_errors (const  VectorXd &y,const  VectorXd &predicted)
69+ {
70+     VectorXd errors{y.array ().pow (0.5 ).array () / (-0.25 ) + y.array ()*predicted.array ().pow (-0.5 ) / 0.5  + predicted.array ().pow (0.5 ) / 0.5 };
71+     return  errors;
72+ }
73+ 
4374// Computes errors (for each observation) based on error metric for a vector
44- VectorXd calculate_errors (const  VectorXd &y,const  VectorXd &predicted,const  VectorXd &sample_weight=VectorXd(0 ),bool loss_function_mse=true )
75+ VectorXd calculate_errors (const  VectorXd &y,const  VectorXd &predicted,const  VectorXd &sample_weight=VectorXd(0 ),const  std::string &family="gaussian" )
4576{   
4677    // Error per observation before adjustment for sample weights
47-     VectorXd residuals{y-predicted};
48-     if (loss_function_mse)
49-         residuals=residuals.array ()*residuals.array ();
50-     else 
51-         residuals=residuals.cwiseAbs ();
52- 
78+     VectorXd errors;
79+     if (family==" gaussian"  )
80+         errors=calculate_gaussian_errors (y,predicted);
81+     else  if (family==" logit"  )
82+         errors=calculate_logit_errors (y,predicted);
83+     else  if (family==" poisson"  )
84+         errors=calculate_poisson_errors (y,predicted);
85+     else  if (family==" gamma"  )
86+         errors=calculate_gamma_errors (y,predicted);
87+     else  if (family==" poissongamma"  )
88+         errors=calculate_poissongamma_errors (y,predicted);
5389    // Adjusting for sample weights if specified
5490    if (sample_weight.size ()>0 )
55-         residuals=residuals .array ()*sample_weight.array ();
91+         errors=errors .array ()*sample_weight.array ();
5692
57-     return  residuals;
93+     return  errors;
94+ }
95+ 
96+ double  calculate_gaussian_error_one_observation (double  y,double  predicted)
97+ {
98+     double  error{y-predicted};
99+     error=error*error;
100+     return  error;
58101}
59102
60103// Computes error for one observation based on error metric
61- double  calculate_error_one_observation (double  y,double  predicted,double  sample_weight=NAN_DOUBLE, bool  loss_function_mse= true )
104+ double  calculate_error_one_observation (double  y,double  predicted,double  sample_weight=NAN_DOUBLE)
62105{   
63106    // Error per observation before adjustment for sample weights
64-     double  residual{y-predicted};
65-     if (loss_function_mse)
66-         residual=residual*residual;
67-     else 
68-         residual=abs (residual);
69- 
107+     double  error{calculate_gaussian_error_one_observation (y,predicted)};    
108+     
70109    // Adjusting for sample weights if specified
71110    if (!std::isnan (sample_weight))
72-         residual=residual *sample_weight;
73-      
74-     return  residual ;
111+         error=error *sample_weight;
112+ 
113+     return  error ;
75114}
76115
77116// Computes overall error based on errors from calculate_errors(), returning one value
@@ -88,6 +127,66 @@ double calculate_error(const VectorXd &errors,const VectorXd &sample_weight=Vect
88127    return  error;
89128}
90129
130+ VectorXd transform_zero_to_negative (const  VectorXd &linear_predictor)
131+ {
132+     VectorXd transformed_linear_predictor{linear_predictor};
133+     for  (size_t  i = 0 ; i < static_cast <size_t >(transformed_linear_predictor.rows ()); ++i)
134+     {
135+         bool  row_is_not_negative{std::isgreaterequal (transformed_linear_predictor[i],0.0 )};
136+         if (row_is_not_negative)
137+             transformed_linear_predictor[i]=SMALL_NEGATIVE_VALUE;
138+     }
139+     return  transformed_linear_predictor;
140+ }
141+ 
142+ VectorXd transform_linear_predictor_to_predictions (const  VectorXd &linear_predictor, const  std::string &family=" gaussian"  )
143+ {
144+     if (family==" gaussian"  )
145+         return  linear_predictor;
146+     else  if (family==" logit"  )
147+     {
148+         VectorXd exp_of_linear_predictor{linear_predictor.array ().exp ()};
149+         return  exp_of_linear_predictor.array () / (1.0  + exp_of_linear_predictor.array ());
150+     }
151+     else  if (family==" poisson"  )
152+         return  linear_predictor.array ().exp ();
153+     else  if (family==" gamma"  )
154+     {
155+         VectorXd transformed_linear_predictor{transform_zero_to_negative (linear_predictor)};
156+         return  -1 /transformed_linear_predictor.array ();
157+     }
158+     else  if (family==" poissongamma"  )
159+     {
160+         VectorXd transformed_linear_predictor{transform_zero_to_negative (linear_predictor)};
161+         return  transformed_linear_predictor.array ().pow (-2 ).array () * 4.0 ;
162+     }
163+     return  VectorXd (0 );
164+ }
165+ 
166+ double  transform_linear_predictor_to_prediction (double  linear_predictor, const  std::string &family=" gaussian"  )
167+ {
168+     if (family==" gaussian"  )
169+         return  linear_predictor;
170+     else  if (family==" logit"  )
171+     {
172+         double  exp_of_linear_predictor{std::exp (linear_predictor)};
173+         return  exp_of_linear_predictor / (1.0  + exp_of_linear_predictor);
174+     }
175+     else  if (family==" poisson"  )
176+         return  std::exp (linear_predictor);
177+     else  if (family==" gamma"  )
178+     {
179+         double  negative_linear_predictor{std::min (linear_predictor,SMALL_NEGATIVE_VALUE)};
180+         return  -1 /negative_linear_predictor;
181+     }
182+     else  if (family==" poissongamma"  )
183+     {
184+         double  negative_linear_predictor{std::min (linear_predictor,SMALL_NEGATIVE_VALUE)};
185+         return  4.0  * std::pow (negative_linear_predictor,-2 );
186+     }
187+     return  NAN_DOUBLE;
188+ }
189+ 
91190// sorts index based on v
92191VectorXi sort_indexes_ascending (const  VectorXd &v)
93192{
0 commit comments