Skip to content

Commit 95a54f0

Browse files
group mse fitting improved and small bugfix
1 parent 2808d1f commit 95a54f0

File tree

3 files changed

+28
-5
lines changed

3 files changed

+28
-5
lines changed

cpp/APLRRegressor.h

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class APLRRegressor
5050
std::set<int> unique_groups_train;
5151
std::set<int> unique_groups_validation;
5252
std::vector<int> interaction_constraints;
53+
bool update_intercept_only_once;
5354

5455
void validate_input_to_fit(const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight,const std::vector<std::string> &X_names,
5556
const std::vector<size_t> &validation_set_indexes, const std::vector<size_t> &prioritized_predictors_indexes,
@@ -109,6 +110,7 @@ class APLRRegressor
109110
void revert_scaling_if_using_log_link_function();
110111
void cap_predictions_to_minmax_in_training(VectorXd &predictions);
111112
std::string compute_raw_base_term_name(const Term &term, const std::string &X_name);
113+
void throw_error_if_m_is_invalid();
112114

113115
public:
114116
double intercept;
@@ -210,6 +212,7 @@ void APLRRegressor::fit(const MatrixXd &X,const VectorXd &y,const VectorXd &samp
210212
throw_error_if_loss_function_does_not_exist();
211213
throw_error_if_link_function_does_not_exist();
212214
throw_error_if_dispersion_parameter_is_invalid();
215+
throw_error_if_m_is_invalid();
213216
validate_input_to_fit(X,y,sample_weight,X_names,validation_set_indexes,prioritized_predictors_indexes,monotonic_constraints,group,
214217
interaction_constraints);
215218
define_training_and_validation_sets(X,y,sample_weight,validation_set_indexes,group);
@@ -543,8 +546,21 @@ void APLRRegressor::initialize(const std::vector<size_t> &prioritized_predictors
543546
terms.clear();
544547
terms.reserve(X_train.cols()*reserved_terms_times_num_x);
545548

546-
intercept=0;
547-
intercept_steps=VectorXd::Constant(m,0);
549+
if(loss_function == "group_mse")
550+
{
551+
update_intercept_only_once = true;
552+
if(sample_weight_train.size()==0)
553+
intercept = y_train.mean();
554+
else
555+
intercept = (y_train.array()*sample_weight_train.array()).sum()/sample_weight_train.array().sum();
556+
}
557+
else
558+
{
559+
update_intercept_only_once = false;
560+
intercept=0;
561+
}
562+
intercept_steps=VectorXd::Constant(m, intercept);
563+
548564

549565
terms_eligible_current.reserve(X_train.cols()*reserved_terms_times_num_x);
550566
size_t X_train_cols{static_cast<size_t>(X_train.cols())};
@@ -710,7 +726,8 @@ void APLRRegressor::execute_boosting_steps()
710726

711727
void APLRRegressor::execute_boosting_step(size_t boosting_step)
712728
{
713-
update_intercept(boosting_step);
729+
if(!update_intercept_only_once)
730+
update_intercept(boosting_step);
714731
bool prioritize_predictors{!abort_boosting && prioritized_predictors_indexes.size()>0};
715732
if(prioritize_predictors)
716733
{
@@ -1532,4 +1549,10 @@ std::string APLRRegressor::get_validation_tuning_metric()
15321549
std::vector<size_t> APLRRegressor::get_validation_indexes()
15331550
{
15341551
return validation_indexes;
1552+
}
1553+
1554+
void APLRRegressor::throw_error_if_m_is_invalid()
1555+
{
1556+
if(m<1)
1557+
throw std::runtime_error("The maximum number of boosting steps, m, must be at least 1.");
15351558
}

cpp/test ALRRegressor group_mse.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ int main()
5454
save_as_csv_file("data/output.csv",predictions);
5555

5656
std::cout<<predictions.mean()<<"\n\n";
57-
tests.push_back(is_approximately_equal(predictions.mean(),23.6919,0.00001));
57+
tests.push_back(is_approximately_equal(predictions.mean(),23.8587,0.00001));
5858

5959
//Test summary
6060
std::cout<<"\n\nTest summary\n"<<"Passed "<<std::accumulate(tests.begin(),tests.end(),0)<<" out of "<<tests.size()<<" tests.";

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
setuptools.setup(
1717
name='aplr',
18-
version='4.0.0',
18+
version='4.1.0',
1919
description='Automatic Piecewise Linear Regression',
2020
ext_modules=[sfc_module],
2121
author="Mathias von Ottenbreit",

0 commit comments

Comments
 (0)