@@ -77,6 +77,7 @@ class APLRRegressor
7777    void  cleanup_after_fit ();
7878    void  validate_that_model_can_be_used (const  MatrixXd &X);
7979    void  throw_error_if_family_does_not_exist ();
80+     void  throw_error_if_link_function_does_not_exist ();
8081    VectorXd calculate_linear_predictor (const  MatrixXd &X);
8182    void  update_linear_predictor_and_predictors ();
8283    void  throw_error_if_response_contains_invalid_values (const  VectorXd &y);
@@ -91,6 +92,7 @@ class APLRRegressor
9192    size_t  m; // Boosting steps to run. Can shrink to auto tuned value after running fit().
9293    double  v; // Learning rate.
9394    std::string family;
95+     std::string link_function;
9496    double  validation_ratio;
9597    size_t  n_jobs; // 0:using all available cores. 1:no multithreading. >1: Using a specified number of cores but not more than is available.
9698    uint_fast32_t  random_state; // For train/validation split. If std::numeric_limits<uint_fast32_t>::lowest() then will randomly set a seed
@@ -112,10 +114,10 @@ class APLRRegressor
112114    VectorXd feature_importance; // Populated in fit() using validation set. Rows are in the same order as in X.
113115
114116    // Methods
115-     APLRRegressor (size_t  m=1000 ,double  v=0.1 ,uint_fast32_t  random_state=std::numeric_limits<uint_fast32_t >::lowest(),std::string family=" gaussian"  ,size_t  n_jobs= 0 , 
116-         double  validation_ratio= 0.2 , double  intercept=NAN_DOUBLE, size_t  reserved_terms_times_num_x= 100 , size_t  bins= 300 , size_t  verbosity= 0 ,
117-         size_t  max_interaction_level =100 ,size_t  max_interactions= 0 ,size_t  min_observations_in_split= 20 ,size_t  ineligible_boosting_steps_added= 10 ,
118-         size_t  max_eligible_terms=5 );
117+     APLRRegressor (size_t  m=1000 ,double  v=0.1 ,uint_fast32_t  random_state=std::numeric_limits<uint_fast32_t >::lowest(),std::string family=" gaussian"  ,
118+         std::string link_function= " identity " ,  size_t  n_jobs= 0 ,  double  validation_ratio= 0.2 , double  intercept=NAN_DOUBLE ,
119+         size_t  reserved_terms_times_num_x =100 ,  size_t  bins= 300 , size_t  verbosity= 0 ,size_t  max_interaction_level= 100 ,size_t  max_interactions= 0 ,
120+         size_t  min_observations_in_split= 20 ,  size_t  ineligible_boosting_steps_added= 10 ,  size_t   max_eligible_terms=5 );
119121    APLRRegressor (const  APLRRegressor &other);
120122    ~APLRRegressor ();
121123    void  fit (const  MatrixXd &X,const  VectorXd &y,const  VectorXd &sample_weight=VectorXd(0 ),const  std::vector<std::string> &X_names={},const  std::vector<size_t > &validation_set_indexes={});
@@ -135,11 +137,11 @@ class APLRRegressor
135137};
136138
137139// Regular constructor
138- APLRRegressor::APLRRegressor (size_t  m,double  v,uint_fast32_t  random_state,std::string family,size_t  n_jobs, double  validation_ratio, double  intercept ,
139-     size_t  reserved_terms_times_num_x, size_t  bins ,size_t  verbosity ,size_t  max_interaction_level ,size_t  max_interactions ,size_t  min_observations_in_split ,
140-     size_t  ineligible_boosting_steps_added,size_t  max_eligible_terms):
140+ APLRRegressor::APLRRegressor (size_t  m,double  v,uint_fast32_t  random_state,std::string family,std::string link_function, size_t  n_jobs ,
141+     double  validation_ratio, double  intercept ,size_t  reserved_terms_times_num_x ,size_t  bins ,size_t  verbosity ,size_t  max_interaction_level ,
142+     size_t  max_interactions, size_t  min_observations_in_split, size_t   ineligible_boosting_steps_added,size_t  max_eligible_terms):
141143        reserved_terms_times_num_x{reserved_terms_times_num_x},intercept{intercept},m{m},v{v},
142-         family{family},validation_ratio{validation_ratio},n_jobs{n_jobs},random_state{random_state},
144+         family{family},link_function{link_function}, validation_ratio{validation_ratio},n_jobs{n_jobs},random_state{random_state},
143145        bins{bins},verbosity{verbosity},max_interaction_level{max_interaction_level},
144146        intercept_steps{VectorXd (0 )},max_interactions{max_interactions},interactions_eligible{0 },validation_error_steps{VectorXd (0 )},
145147        min_observations_in_split{min_observations_in_split},ineligible_boosting_steps_added{ineligible_boosting_steps_added},
@@ -150,7 +152,7 @@ APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::s
150152// Copy constructor
151153APLRRegressor::APLRRegressor (const  APLRRegressor &other):
152154    reserved_terms_times_num_x{other.reserved_terms_times_num_x },intercept{other.intercept },terms{other.terms },m{other.m },v{other.v },
153-     family{other.family },validation_ratio{other.validation_ratio },
155+     family{other.family },link_function{other. link_function }, validation_ratio{other.validation_ratio },
154156    n_jobs{other.n_jobs },random_state{other.random_state },bins{other.bins },
155157    verbosity{other.verbosity },term_names{other.term_names },term_coefficients{other.term_coefficients },
156158    max_interaction_level{other.max_interaction_level },intercept_steps{other.intercept_steps },
@@ -173,6 +175,7 @@ APLRRegressor::~APLRRegressor()
173175void  APLRRegressor::fit (const  MatrixXd &X,const  VectorXd &y,const  VectorXd &sample_weight,const  std::vector<std::string> &X_names,const  std::vector<size_t > &validation_set_indexes)
174176{
175177    throw_error_if_family_does_not_exist ();
178+     throw_error_if_link_function_does_not_exist ();
176179    validate_input_to_fit (X,y,sample_weight,X_names,validation_set_indexes);
177180    define_training_and_validation_sets (X,y,sample_weight,validation_set_indexes);
178181    initialize ();
@@ -212,11 +215,11 @@ void APLRRegressor::throw_error_if_validation_set_indexes_has_invalid_indexes(co
212215
// Diff hunk for the response-range dispatcher: picks one of three validators for y and calls it.
// Lines marked '-' (old side) choose the validator from `family`; lines marked '+' (new side)
// choose it from `link_function`. If no branch matches, y is not range-checked by this function.
// NOTE(review): the padding inside the quoted literals looks like residue from the diff rendering,
// and it is not uniform (this hunk prints " log " while the link-function validator hunk prints
// " log") — confirm the exact literals against the real source file before relying on them.
213216void  APLRRegressor::throw_error_if_response_contains_invalid_values (const  VectorXd &y)
214217{
    // First branch: delegate to the check that rejects any y below 0 or above 1.
215-     if (family ==" logit"  )
218+     if (link_function ==" logit"  )
216219        throw_error_if_response_is_not_between_0_and_1 (y);
    // Second branch: delegate to the check that rejects any y below 0 (zero is allowed).
217-     else  if (family ==" poisson "   || family ==" poissongamma "  )
220+     else  if (link_function ==" log "   || link_function ==" inverseroot "  )
218221        throw_error_if_response_is_negative (y);
    // Third branch: delegate to the check that rejects any y that is zero or below.
219-     else  if (family ==" gamma "   || family ==" inversegaussian "  )
222+     else  if (link_function ==" inverse "   || link_function ==" inversesquare "  )
220223        throw_error_if_response_is_not_greater_than_zero (y);
221224}
222225
@@ -225,21 +228,21 @@ void APLRRegressor::throw_error_if_response_is_not_between_0_and_1(const VectorX
225228    bool  response_is_less_than_zero{(y.array ()<0.0 ).any ()};
226229    bool  response_is_greater_than_one{(y.array ()>1.0 ).any ()};
227230    if (response_is_less_than_zero || response_is_greater_than_one)
228-         throw  std::runtime_error (" Response values for "  +family +"  models  cannot be less than zero or greater than one."  );   
231+         throw  std::runtime_error (" Response values for "  +link_function +"  link functions  cannot be less than zero or greater than one."  );   
229232}
230233
// Throws std::runtime_error when at least one element of y is strictly below 0.0; zero passes.
// The '-' line is the old message (built from `family`), the '+' line is the new message
// (built from `link_function`). Only the message text differs between the two sides.
231234void  APLRRegressor::throw_error_if_response_is_negative (const  VectorXd &y)
232235{
    // Element-wise comparison over the whole vector, reduced with any().
233236    bool  response_is_less_than_zero{(y.array ()<0.0 ).any ()};
234237    if (response_is_less_than_zero)
235-         throw  std::runtime_error (" Response values for "  +family +"  models  cannot be less than zero."  );   
238+         throw  std::runtime_error (" Response values for "  +link_function +"  link functions  cannot be less than zero."  );   
236239}
237240
// Throws std::runtime_error when at least one element of y is less than or equal to 0.0,
// i.e. zero is rejected here as well as negative values.
// The '-' line is the old message (built from `family`), the '+' line is the new message
// (built from `link_function`). Only the message text differs between the two sides.
238241void  APLRRegressor::throw_error_if_response_is_not_greater_than_zero (const  VectorXd &y)
239242{
    // Element-wise <= comparison over the whole vector, reduced with any().
240243    bool  response_is_not_greater_than_zero{(y.array ()<=0.0 ).any ()};
241244    if (response_is_not_greater_than_zero)
242-         throw  std::runtime_error (" Response values for "  +family +"  models  must be greater than zero."  );   
245+         throw  std::runtime_error (" Response values for "  +link_function +"  link functions  must be greater than zero."  );   
243246

244247}
245248
@@ -337,11 +340,11 @@ void APLRRegressor::initialize()
337340        }
338341    }
339342
340-     linear_predictor_current=VectorXd::Constant (y_train.size (),0 );
343+     linear_predictor_current=VectorXd::Constant (y_train.size (),intercept );
341344    linear_predictor_null_model=linear_predictor_current;
342-     linear_predictor_current_validation=VectorXd::Constant (y_validation.size (),0 );
343-     predictions_current=transform_linear_predictor_to_predictions (linear_predictor_current,family );
344-     predictions_current_validation=transform_linear_predictor_to_predictions (linear_predictor_current_validation,family );
345+     linear_predictor_current_validation=VectorXd::Constant (y_validation.size (),intercept );
346+     predictions_current=transform_linear_predictor_to_predictions (linear_predictor_current,link_function );
347+     predictions_current_validation=transform_linear_predictor_to_predictions (linear_predictor_current_validation,link_function );
345348
346349    validation_error_steps.resize (m);
347350    validation_error_steps.setConstant (std::numeric_limits<double >::infinity ());
@@ -379,14 +382,14 @@ VectorXd APLRRegressor::calculate_neg_gradient_current(const VectorXd &y,const V
379382    VectorXd output;
380383    if (family==" gaussian"  )
381384        output=y-predictions_current;
382-     else  if (family==" logit "  )
385+     else  if (family==" binomial "  )
383386        output=y.array () / predictions_current.array () - (y.array ()-1.0 ) / (predictions_current.array ()-1.0 );
384387    else  if (family==" poisson"  )
385388        output=y.array () / predictions_current.array () - 1 ;
386389    else  if (family==" gamma"  )
387390        output=(y.array () - predictions_current.array ()) / predictions_current.array () / predictions_current.array ();
388391    else  if (family==" poissongamma"  )
389-         output=( y.array () / predictions_current.array ().pow (1.5 ) - predictions_current.array ().pow (-0.5 ) );
392+         output=y.array () / predictions_current.array ().pow (1.5 ) - predictions_current.array ().pow (-0.5 );
390393    else  if (family==" inversegaussian"  )
391394        output=y.array () / predictions_current.array ().pow (3.0 ) - predictions_current.array ().pow (-2.0 );
392395    return  output;
@@ -692,8 +695,9 @@ void APLRRegressor::select_the_best_term_and_update_errors(size_t boosting_step)
692695    if (validation_error_is_invalid)
693696    {
694697        abort_boosting=true ;
695-         std::string warning_message{" Warning: Encountered numerical problems when calculating prediction errors."  };
696-         if (family==" poisson"   || family==" poissongamma"   ||family==" gamma"   || family==" inversegaussian"  )
698+         std::string warning_message{" Warning: Encountered numerical problems when calculating prediction errors in the previous boosting step. Not continuing with further boosting steps."  };
699+         bool  show_additional_warning{family==" poisson"   || family==" poissongamma"   || family==" gamma"   || family==" inversegaussian"   || (link_function!=" identity"   && link_function!=" logit"  )};
700+         if (show_additional_warning)
697701            warning_message+="  A reason may be too large response values."  ;
698702        std::cout<<warning_message<<" \n "  ;
699703    }
@@ -703,8 +707,8 @@ void APLRRegressor::update_linear_predictor_and_predictors()
703707{
704708    linear_predictor_current+=linear_predictor_update;
705709    linear_predictor_current_validation+=linear_predictor_update_validation;
706-     predictions_current=transform_linear_predictor_to_predictions (linear_predictor_current,family );
707-     predictions_current_validation=transform_linear_predictor_to_predictions (linear_predictor_current_validation,family );
710+     predictions_current=transform_linear_predictor_to_predictions (linear_predictor_current,link_function );
711+     predictions_current_validation=transform_linear_predictor_to_predictions (linear_predictor_current_validation,link_function );
708712}
709713
710714void  APLRRegressor::update_gradient_and_errors ()
@@ -960,7 +964,7 @@ VectorXd APLRRegressor::predict(const MatrixXd &X)
960964    validate_that_model_can_be_used (X);
961965
962966    VectorXd linear_predictor{calculate_linear_predictor (X)};
963-     VectorXd predictions{transform_linear_predictor_to_predictions (linear_predictor,family )};
967+     VectorXd predictions{transform_linear_predictor_to_predictions (linear_predictor,link_function )};
964968
965969    return  predictions;
966970}
@@ -1053,7 +1057,7 @@ void APLRRegressor::throw_error_if_family_does_not_exist()
10531057    bool  family_exists{false };
10541058    if (family==" gaussian"  )
10551059        family_exists=true ;
1056-     else  if (family==" logit "  )
1060+     else  if (family==" binomial "  )
10571061        family_exists=true ;
10581062    else  if (family==" poisson"  )
10591063        family_exists=true ;
@@ -1065,4 +1069,23 @@ void APLRRegressor::throw_error_if_family_does_not_exist()
10651069        family_exists=true ;        
10661070    if (!family_exists)
10671071        throw  std::runtime_error (" Family "  +family+"  is not available in APLR."  );   
1072+ }
1073+ 
// Added by this diff (every line except the closing brace is a '+' line).
// Checks the `link_function` member against the six string literals compared below and throws
// std::runtime_error, with the offending value embedded in the message, when none of them match.
// In the fit() hunk of this same diff it is called right after throw_error_if_family_does_not_exist().
1074+ void  APLRRegressor::throw_error_if_link_function_does_not_exist ()
1075+ {
    // Starts false; any matching branch below flips it to true.
1076+     bool  link_function_exists{false };
1077+     if (link_function==" identity"  )
1078+         link_function_exists=true ;
1079+     else  if (link_function==" logit"  )
1080+         link_function_exists=true ;
1081+     else  if (link_function==" log"  )
1082+         link_function_exists=true ;
1083+     else  if (link_function==" inverseroot"  )
1084+         link_function_exists=true ;
1085+     else  if (link_function==" inverse"  )
1086+         link_function_exists=true ;        
1087+     else  if (link_function==" inversesquare"  )
1088+         link_function_exists=true ;        
    // No literal matched: report the unknown name to the caller.
1089+     if (!link_function_exists)
1090+         throw  std::runtime_error (" Link function "  +link_function+"  is not available in APLR."  );
10681091}
0 commit comments