Skip to content

Commit 745fadb

Browse files
reverted to 1.8.0
1 parent 8fef77c commit 745fadb

File tree

1 file changed

+29
-15
lines changed

1 file changed

+29
-15
lines changed

cpp/APLRRegressor.h

Lines changed: 29 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,6 @@ class APLRRegressor
5151
void throw_error_if_validation_set_indexes_has_invalid_indexes(const VectorXd &y, const std::vector<size_t> &validation_set_indexes);
5252
void define_training_and_validation_sets(const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight, const std::vector<size_t> &validation_set_indexes);
5353
void initialize();
54-
void initialize_intercept();
5554
bool check_if_base_term_has_only_one_unique_value(size_t base_term);
5655
void add_term_to_terms_eligible_current(Term &term);
5756
VectorXd calculate_neg_gradient_current();
@@ -63,6 +62,7 @@ class APLRRegressor
6362
void estimate_split_points_for_interactions_to_consider();
6463
void sort_errors_for_interactions_to_consider();
6564
void add_promising_interactions_and_select_the_best_one();
65+
void update_intercept(size_t boosting_step);
6666
void select_the_best_term_and_update_errors(size_t boosting_step);
6767
void update_terms(size_t boosting_step);
6868
void update_gradient_and_errors();
@@ -419,7 +419,8 @@ void APLRRegressor::initialize()
419419
terms.reserve(X_train.cols()*reserved_terms_times_num_x);
420420
terms.clear();
421421

422-
initialize_intercept();
422+
intercept=0;
423+
intercept_steps=VectorXd::Constant(m,0);
423424

424425
terms_eligible_current.reserve(X_train.cols()*reserved_terms_times_num_x);
425426
for (size_t i = 0; i < static_cast<size_t>(X_train.cols()); ++i)
@@ -445,16 +446,6 @@ void APLRRegressor::initialize()
445446
update_gradient_and_errors();
446447
}
447448

448-
// Removed by this commit: seeded the intercept with the (weighted) mean of the
// training response and pre-filled intercept_steps with that constant for all
// m boosting steps. Superseded by the per-step update_intercept() added below.
void APLRRegressor::initialize_intercept()
449-
{
450-
// True when no per-observation training weights were supplied.
bool sample_weights_are_not_provided{sample_weight_train.size()==0};
451-
if(sample_weights_are_not_provided)
452-
// Unweighted case: plain arithmetic mean of the training targets.
intercept=y_train.mean();
453-
else
454-
// Weighted case: sum(y_i * w_i) / sum(w_i) over the training set.
intercept=(y_train.array()*sample_weight_train.array()).sum()/sample_weight_train.array().sum();
455-
// Record this starting intercept for every one of the m boosting steps.
intercept_steps=VectorXd::Constant(m,intercept);
456-
}
457-
458449
bool APLRRegressor::check_if_base_term_has_only_one_unique_value(size_t base_term)
459450
{
460451
size_t rows{static_cast<size_t>(X_train.rows())};
@@ -523,14 +514,37 @@ void APLRRegressor::execute_boosting_steps()
523514

524515
// Runs one boosting iteration. Diff: the old version went straight to term
// search; the new version first updates the intercept, and only searches for
// terms if that intercept update did not abort boosting.
void APLRRegressor::execute_boosting_step(size_t boosting_step)
525516
{
526-
find_best_split_for_each_eligible_term();
527-
consider_interactions();
528-
select_the_best_term_and_update_errors(boosting_step);
517+
// New in this commit: gradient-based intercept update each step; it may set
// abort_boosting (see update_intercept), which then skips the term search.
update_intercept(boosting_step);
518+
if(!abort_boosting)
519+
{
520+
// Evaluate candidate splits for every currently eligible term.
find_best_split_for_each_eligible_term();
521+
// Also evaluate candidate interaction terms.
consider_interactions();
522+
// Pick the best candidate and refresh the error state for this step.
select_the_best_term_and_update_errors(boosting_step);
523+
}
529524
// Stop entirely once boosting has been aborted (e.g. validation error rose).
if(abort_boosting) return;
530525
update_term_eligibility();
531526
print_summary_after_boosting_step(boosting_step);
532527
}
533528

529+
// Added by this commit: performs a learning-rate-scaled intercept update from
// the current negative gradient, propagates it through the linear predictor,
// and commits it only if the validation check does not abort boosting.
// boosting_step: index of the current boosting iteration (indexes intercept_steps).
void APLRRegressor::update_intercept(size_t boosting_step)
530+
{
531+
double intercept_update;
532+
if(sample_weight_train.size()==0)
533+
// Unweighted: step = learning rate v times the mean negative gradient.
intercept_update=v*neg_gradient_current.mean();
534+
else
535+
// Weighted: v times the weighted mean of the negative gradient.
intercept_update=v*(neg_gradient_current.array()*sample_weight_train.array()).sum()/sample_weight_train.array().sum();
536+
// An intercept change shifts every observation equally, so the predictor
// updates are constant vectors (train sized by neg_gradient_current,
// validation sized by y_validation).
linear_predictor_update=VectorXd::Constant(neg_gradient_current.size(),intercept_update);
537+
linear_predictor_update_validation=VectorXd::Constant(y_validation.size(),intercept_update);
538+
update_linear_predictor_and_predictors();
539+
update_gradient_and_errors();
540+
// May set abort_boosting if validation error deteriorates — TODO confirm
// against calculate_and_validate_validation_error's definition (not shown here).
calculate_and_validate_validation_error(boosting_step);
541+
if(!abort_boosting)
542+
{
543+
// Commit the update and record the running intercept for this step.
intercept+=intercept_update;
544+
intercept_steps[boosting_step]=intercept;
545+
}
546+
}
547+
534548
void APLRRegressor::update_linear_predictor_and_predictors()
535549
{
536550
linear_predictor_current+=linear_predictor_update;

0 commit comments

Comments (0)