Skip to content

Commit bd9ab1a

Browse files
Merge pull request #11 from ottenbreit-data-science/glm
Added the possibility to train various kinds of GLM models
2 parents 16dc31e + f05cc70 commit bd9ab1a

15 files changed

+496
-111
lines changed

aplr/aplr.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66

77
class APLRRegressor():
8-
def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, loss_function_mse:bool=True, n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=100, max_interactions:int=0, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0):
8+
def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=100, max_interactions:int=0, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0):
99
self.m=m
1010
self.v=v
1111
self.random_state=random_state
12-
self.loss_function_mse=loss_function_mse
12+
self.family=family
1313
self.n_jobs=n_jobs
1414
self.validation_ratio=validation_ratio
1515
self.intercept=intercept
@@ -30,7 +30,7 @@ def __set_params_cpp(self):
3030
self.APLRRegressor.m=self.m
3131
self.APLRRegressor.v=self.v
3232
self.APLRRegressor.random_state=self.random_state
33-
self.APLRRegressor.loss_function_mse=self.loss_function_mse
33+
self.APLRRegressor.family=self.family
3434
self.APLRRegressor.n_jobs=self.n_jobs
3535
self.APLRRegressor.validation_ratio=self.validation_ratio
3636
self.APLRRegressor.intercept=self.intercept
@@ -87,7 +87,7 @@ def get_m(self)->int:
8787

8888
#For sklearn
8989
def get_params(self, deep=True):
90-
return {"m": self.m, "v": self.v,"random_state":self.random_state,"loss_function_mse":self.loss_function_mse,"n_jobs":self.n_jobs,"validation_ratio":self.validation_ratio,"intercept":self.intercept,"bins":self.bins,"max_interaction_level":self.max_interaction_level,"max_interactions":self.max_interactions,"verbosity":self.verbosity,"min_observations_in_split":self.min_observations_in_split,"ineligible_boosting_steps_added":self.ineligible_boosting_steps_added,"max_eligible_terms":self.max_eligible_terms}
90+
return {"m": self.m, "v": self.v,"random_state":self.random_state,"family":self.family,"n_jobs":self.n_jobs,"validation_ratio":self.validation_ratio,"intercept":self.intercept,"bins":self.bins,"max_interaction_level":self.max_interaction_level,"max_interactions":self.max_interactions,"verbosity":self.verbosity,"min_observations_in_split":self.min_observations_in_split,"ineligible_boosting_steps_added":self.ineligible_boosting_steps_added,"max_eligible_terms":self.max_eligible_terms}
9191

9292
#For sklearn
9393
def set_params(self, **parameters):

cpp/APLRRegressor.h

Lines changed: 85 additions & 40 deletions
Large diffs are not rendered by default.

cpp/constants.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#pragma once
22
#include <limits>
33

4-
const double NAN_DOUBLE{ std::numeric_limits<double>::quiet_NaN() };
4+
const double NAN_DOUBLE{ std::numeric_limits<double>::quiet_NaN() };
5+
const double SMALL_NEGATIVE_VALUE{-0.001};

cpp/functions.h

Lines changed: 118 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,38 +40,77 @@ static bool check_if_approximately_zero(TReal a, TReal tolerance = std::numeric_
4040
return false;
4141
}
4242

43+
VectorXd calculate_gaussian_errors(const VectorXd &y,const VectorXd &predicted)
44+
{
45+
VectorXd errors{y-predicted};
46+
errors=errors.array()*errors.array();
47+
return errors;
48+
}
49+
50+
VectorXd calculate_logit_errors(const VectorXd &y,const VectorXd &predicted)
51+
{
52+
VectorXd errors{-y.array() * predicted.array().log() - (1.0-y.array()).array() * (1.0-predicted.array()).log()};
53+
return errors;
54+
}
55+
56+
VectorXd calculate_poisson_errors(const VectorXd &y,const VectorXd &predicted)
57+
{
58+
VectorXd errors{predicted.array() - y.array()*predicted.array().log()};
59+
return errors;
60+
}
61+
62+
VectorXd calculate_gamma_errors(const VectorXd &y,const VectorXd &predicted)
63+
{
64+
VectorXd errors{predicted.array().log() - y.array().log() + y.array()/predicted.array()-1};
65+
return errors;
66+
}
67+
68+
VectorXd calculate_poissongamma_errors(const VectorXd &y,const VectorXd &predicted)
69+
{
70+
VectorXd errors{y.array().pow(0.5).array() / (-0.25) + y.array()*predicted.array().pow(-0.5) / 0.5 + predicted.array().pow(0.5) / 0.5};
71+
return errors;
72+
}
73+
4374
//Computes errors (for each observation) based on error metric for a vector
44-
VectorXd calculate_errors(const VectorXd &y,const VectorXd &predicted,const VectorXd &sample_weight=VectorXd(0),bool loss_function_mse=true)
75+
VectorXd calculate_errors(const VectorXd &y,const VectorXd &predicted,const VectorXd &sample_weight=VectorXd(0),const std::string &family="gaussian")
4576
{
4677
//Error per observation before adjustment for sample weights
47-
VectorXd residuals{y-predicted};
48-
if(loss_function_mse)
49-
residuals=residuals.array()*residuals.array();
50-
else
51-
residuals=residuals.cwiseAbs();
52-
78+
VectorXd errors;
79+
if(family=="gaussian")
80+
errors=calculate_gaussian_errors(y,predicted);
81+
else if(family=="logit")
82+
errors=calculate_logit_errors(y,predicted);
83+
else if(family=="poisson")
84+
errors=calculate_poisson_errors(y,predicted);
85+
else if(family=="gamma")
86+
errors=calculate_gamma_errors(y,predicted);
87+
else if(family=="poissongamma")
88+
errors=calculate_poissongamma_errors(y,predicted);
5389
//Adjusting for sample weights if specified
5490
if(sample_weight.size()>0)
55-
residuals=residuals.array()*sample_weight.array();
91+
errors=errors.array()*sample_weight.array();
5692

57-
return residuals;
93+
return errors;
94+
}
95+
96+
double calculate_gaussian_error_one_observation(double y,double predicted)
97+
{
98+
double error{y-predicted};
99+
error=error*error;
100+
return error;
58101
}
59102

60103
//Computes error for one observation based on error metric
61-
double calculate_error_one_observation(double y,double predicted,double sample_weight=NAN_DOUBLE,bool loss_function_mse=true)
104+
double calculate_error_one_observation(double y,double predicted,double sample_weight=NAN_DOUBLE)
62105
{
63106
//Error per observation before adjustment for sample weights
64-
double residual{y-predicted};
65-
if(loss_function_mse)
66-
residual=residual*residual;
67-
else
68-
residual=abs(residual);
69-
107+
double error{calculate_gaussian_error_one_observation(y,predicted)};
108+
70109
//Adjusting for sample weights if specified
71110
if(!std::isnan(sample_weight))
72-
residual=residual*sample_weight;
73-
74-
return residual;
111+
error=error*sample_weight;
112+
113+
return error;
75114
}
76115

77116
//Computes overall error based on errors from calculate_errors(), returning one value
@@ -88,6 +127,66 @@ double calculate_error(const VectorXd &errors,const VectorXd &sample_weight=Vect
88127
return error;
89128
}
90129

130+
VectorXd transform_zero_to_negative(const VectorXd &linear_predictor)
131+
{
132+
VectorXd transformed_linear_predictor{linear_predictor};
133+
for (size_t i = 0; i < static_cast<size_t>(transformed_linear_predictor.rows()); ++i)
134+
{
135+
bool row_is_not_negative{std::isgreaterequal(transformed_linear_predictor[i],0.0)};
136+
if(row_is_not_negative)
137+
transformed_linear_predictor[i]=SMALL_NEGATIVE_VALUE;
138+
}
139+
return transformed_linear_predictor;
140+
}
141+
142+
VectorXd transform_linear_predictor_to_predictions(const VectorXd &linear_predictor, const std::string &family="gaussian")
143+
{
144+
if(family=="gaussian")
145+
return linear_predictor;
146+
else if(family=="logit")
147+
{
148+
VectorXd exp_of_linear_predictor{linear_predictor.array().exp()};
149+
return exp_of_linear_predictor.array() / (1.0 + exp_of_linear_predictor.array());
150+
}
151+
else if(family=="poisson")
152+
return linear_predictor.array().exp();
153+
else if(family=="gamma")
154+
{
155+
VectorXd transformed_linear_predictor{transform_zero_to_negative(linear_predictor)};
156+
return -1/transformed_linear_predictor.array();
157+
}
158+
else if(family=="poissongamma")
159+
{
160+
VectorXd transformed_linear_predictor{transform_zero_to_negative(linear_predictor)};
161+
return transformed_linear_predictor.array().pow(-2).array() * 4.0;
162+
}
163+
return VectorXd(0);
164+
}
165+
166+
double transform_linear_predictor_to_prediction(double linear_predictor, const std::string &family="gaussian")
167+
{
168+
if(family=="gaussian")
169+
return linear_predictor;
170+
else if(family=="logit")
171+
{
172+
double exp_of_linear_predictor{std::exp(linear_predictor)};
173+
return exp_of_linear_predictor / (1.0 + exp_of_linear_predictor);
174+
}
175+
else if(family=="poisson")
176+
return std::exp(linear_predictor);
177+
else if(family=="gamma")
178+
{
179+
double negative_linear_predictor{std::min(linear_predictor,SMALL_NEGATIVE_VALUE)};
180+
return -1/negative_linear_predictor;
181+
}
182+
else if(family=="poissongamma")
183+
{
184+
double negative_linear_predictor{std::min(linear_predictor,SMALL_NEGATIVE_VALUE)};
185+
return 4.0 * std::pow(negative_linear_predictor,-2);
186+
}
187+
return NAN_DOUBLE;
188+
}
189+
91190
//sorts index based on v
92191
VectorXi sort_indexes_ascending(const VectorXd &v)
93192
{

cpp/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ int main()
1616
model.v=0.5;
1717
model.bins=300;
1818
model.n_jobs=0;
19-
model.loss_function_mse=true;
19+
model.family="gaussian";
2020
model.verbosity=3;
2121
model.min_observations_in_split=10;
2222
//model.max_interaction_level=0;

cpp/pythonbinding.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ namespace py = pybind11;
1111

1212
PYBIND11_MODULE(aplr_cpp, m) {
1313
py::class_<APLRRegressor>(m, "APLRRegressor",py::module_local())
14-
.def(py::init<int&, double&, int&, bool&,int&,double&,double&,int&,int&,int&,int&,int&,int&,int&,int&>(),
15-
py::arg("m")=1000,py::arg("v")=0.1,py::arg("random_state")=0,py::arg("loss_function_mse")=true,
14+
.def(py::init<int&, double&, int&, std::string&,int&,double&,double&,int&,int&,int&,int&,int&,int&,int&,int&>(),
15+
py::arg("m")=1000,py::arg("v")=0.1,py::arg("random_state")=0,py::arg("family")="gaussian",
1616
py::arg("n_jobs")=0,py::arg("validation_ratio")=0.2,py::arg("intercept")=NAN_DOUBLE,
1717
py::arg("reserved_terms_times_num_x")=100,py::arg("bins")=300,py::arg("verbosity")=0,
1818
py::arg("max_interaction_level")=100,py::arg("max_interactions")=0,py::arg("min_observations_in_split")=20,
@@ -40,7 +40,7 @@ PYBIND11_MODULE(aplr_cpp, m) {
4040
.def_readwrite("max_interactions", &APLRRegressor::max_interactions)
4141
.def_readwrite("min_observations_in_split", &APLRRegressor::min_observations_in_split)
4242
.def_readwrite("interactions_eligible", &APLRRegressor::interactions_eligible)
43-
.def_readwrite("loss_function_mse", &APLRRegressor::loss_function_mse)
43+
.def_readwrite("family", &APLRRegressor::family)
4444
.def_readwrite("validation_ratio", &APLRRegressor::validation_ratio)
4545
.def_readwrite("validation_error_steps", &APLRRegressor::validation_error_steps)
4646
.def_readwrite("n_jobs", &APLRRegressor::n_jobs)
@@ -57,7 +57,7 @@ PYBIND11_MODULE(aplr_cpp, m) {
5757
.def(py::pickle(
5858
[](const APLRRegressor &a) { // __getstate__
5959
/* Return a tuple that fully encodes the state of the object */
60-
return py::make_tuple(a.m,a.v,a.random_state,a.loss_function_mse,a.n_jobs,a.validation_ratio,a.intercept,a.bins,a.verbosity,
60+
return py::make_tuple(a.m,a.v,a.random_state,a.family,a.n_jobs,a.validation_ratio,a.intercept,a.bins,a.verbosity,
6161
a.max_interaction_level,a.max_interactions,a.validation_error_steps,a.term_names,a.term_coefficients,a.terms,a.intercept_steps,
6262
a.interactions_eligible,a.min_observations_in_split,a.ineligible_boosting_steps_added,a.max_eligible_terms,
6363
a.number_of_base_terms,a.feature_importance);
@@ -67,7 +67,7 @@ PYBIND11_MODULE(aplr_cpp, m) {
6767
throw std::runtime_error("Invalid state!");
6868

6969
/* Create a new C++ instance */
70-
APLRRegressor a(t[0].cast<size_t>(),t[1].cast<double>(),t[2].cast<uint_fast32_t>(),t[3].cast<bool>(),t[4].cast<size_t>(),t[5].cast<double>(),
70+
APLRRegressor a(t[0].cast<size_t>(),t[1].cast<double>(),t[2].cast<uint_fast32_t>(),t[3].cast<std::string>(),t[4].cast<size_t>(),t[5].cast<double>(),
7171
t[6].cast<double>(),100,t[7].cast<size_t>(),t[8].cast<size_t>(),t[9].cast<size_t>(),t[10].cast<double>(),t[17].cast<size_t>());
7272

7373
a.validation_error_steps=t[11].cast<VectorXd>();

0 commit comments

Comments
 (0)