Skip to content

Commit 021d88d

Browse files
wip
1 parent 745fadb commit 021d88d

File tree

1 file changed

+9
-14
lines changed

1 file changed

+9
-14
lines changed

cpp/APLRRegressor.h

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class APLRRegressor
6262
void estimate_split_points_for_interactions_to_consider();
6363
void sort_errors_for_interactions_to_consider();
6464
void add_promising_interactions_and_select_the_best_one();
65-
void update_intercept(size_t boosting_step);
65+
void update_intercept();
6666
void select_the_best_term_and_update_errors(size_t boosting_step);
6767
void update_terms(size_t boosting_step);
6868
void update_gradient_and_errors();
@@ -505,6 +505,7 @@ VectorXd APLRRegressor::differentiate_predictions()
505505
void APLRRegressor::execute_boosting_steps()
506506
{
507507
abort_boosting = false;
508+
update_intercept();
508509
for (size_t boosting_step = 0; boosting_step < m; ++boosting_step)
509510
{
510511
execute_boosting_step(boosting_step);
@@ -514,7 +515,6 @@ void APLRRegressor::execute_boosting_steps()
514515

515516
void APLRRegressor::execute_boosting_step(size_t boosting_step)
516517
{
517-
update_intercept(boosting_step);
518518
if(!abort_boosting)
519519
{
520520
find_best_split_for_each_eligible_term();
@@ -526,23 +526,18 @@ void APLRRegressor::execute_boosting_step(size_t boosting_step)
526526
print_summary_after_boosting_step(boosting_step);
527527
}
528528

529-
void APLRRegressor::update_intercept(size_t boosting_step)
529+
void APLRRegressor::update_intercept()
530530
{
531-
double intercept_update;
532531
if(sample_weight_train.size()==0)
533-
intercept_update=v*neg_gradient_current.mean();
532+
intercept=neg_gradient_current.mean();
534533
else
535-
intercept_update=v*(neg_gradient_current.array()*sample_weight_train.array()).sum()/sample_weight_train.array().sum();
536-
linear_predictor_update=VectorXd::Constant(neg_gradient_current.size(),intercept_update);
537-
linear_predictor_update_validation=VectorXd::Constant(y_validation.size(),intercept_update);
534+
intercept=(neg_gradient_current.array()*sample_weight_train.array()).sum()/sample_weight_train.array().sum();
535+
intercept_steps=VectorXd::Constant(m,intercept);
536+
linear_predictor_update=VectorXd::Constant(neg_gradient_current.size(),intercept);
537+
linear_predictor_update_validation=VectorXd::Constant(y_validation.size(),intercept);
538538
update_linear_predictor_and_predictors();
539539
update_gradient_and_errors();
540-
calculate_and_validate_validation_error(boosting_step);
541-
if(!abort_boosting)
542-
{
543-
intercept+=intercept_update;
544-
intercept_steps[boosting_step]=intercept;
545-
}
540+
calculate_and_validate_validation_error(0);
546541
}
547542

548543
void APLRRegressor::update_linear_predictor_and_predictors()

0 commit comments

Comments
 (0)