
Commit d807d15

added the possibility to specify a custom loss function
1 parent 0758ac5 · commit d807d15

8 files changed (+259 −19 lines)

API_REFERENCE_FOR_REGRESSION.md

Lines changed: 21 additions & 3 deletions
@@ -1,6 +1,6 @@
 # APLRRegressor
 
-## class aplr.APLRRegressor(m:int=1000, v:float=0.1, random_state:int=0, loss_function:str="mse", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, dispersion_parameter:float=1.5, validation_tuning_metric:str="default", quantile:float=0.5, calculate_custom_validation_error_function:Optional[Callable[[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike, npt.ArrayLike], float]]=None)
+## class aplr.APLRRegressor(m:int=1000, v:float=0.1, random_state:int=0, loss_function:str="mse", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, dispersion_parameter:float=1.5, validation_tuning_metric:str="default", quantile:float=0.5, calculate_custom_validation_error_function:Optional[Callable[[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike, npt.ArrayLike], float]]=None, calculate_custom_loss_function:Optional[Callable[[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike, npt.ArrayLike], float]]=None, calculate_custom_negative_gradient_function:Optional[Callable[[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike], npt.ArrayLike]]=None)
 
 ### Constructor parameters
 
@@ -14,7 +14,7 @@ The learning rate. Must be greater than zero and not more than one. The higher t
 Used to randomly split training observations into training and validation if ***validation_set_indexes*** is not specified when fitting.
 
 #### loss_function (default = "mse")
-Determines the loss function used. Allowed values are "mse", "binomial", "poisson", "gamma", "tweedie", "group_mse", "mae", "quantile", "negative_binomial", "cauchy" and "weibull". This is used together with ***link_function***. When ***loss_function*** is "group_mse" then the "group" argument in the ***fit*** method must be provided. In the latter case APLR will try to minimize group MSE when training the model. The ***loss_function*** "quantile" is used together with the ***quantile*** constructor parameter.
+Determines the loss function used. Allowed values are "mse", "binomial", "poisson", "gamma", "tweedie", "group_mse", "mae", "quantile", "negative_binomial", "cauchy", "weibull" and "custom_function". This is used together with ***link_function***. When ***loss_function*** is "group_mse" then the "group" argument in the ***fit*** method must be provided. In the latter case APLR will try to minimize group MSE when training the model. The ***loss_function*** "quantile" is used together with the ***quantile*** constructor parameter. When ***loss_function*** is "custom_function" then the constructor parameters ***calculate_custom_loss_function*** and ***calculate_custom_negative_gradient_function***, both described below, must be provided.
 
 #### link_function (default = "identity")
 Determines how the linear predictor is transformed to predictions. Allowed values are "identity", "logit" and "log". For an ordinary regression model use ***loss_function*** "mse" and ***link_function*** "identity". For logistic regression use ***loss_function*** "binomial" and ***link_function*** "logit". For a multiplicative model use the "log" ***link_function***. The "log" ***link_function*** often works best with a "poisson", "gamma", "tweedie", "negative_binomial" or "weibull" ***loss_function***, depending on the data. The ***loss_function*** "poisson", "gamma", "tweedie", "negative_binomial" or "weibull" should only be used with the "log" ***link_function***. Inappropriate combinations of ***loss_function*** and ***link_function*** may result in a warning message when fitting the model and/or a poor model fit. Please note that values other than "identity" typically require a significantly higher ***m*** (or ***v***) in order to converge.
@@ -56,14 +56,32 @@ Specifies which metric to use for validating the model and tuning ***m***. Avail
 Specifies the quantile to use when ***loss_function*** is "quantile".
 
 #### calculate_custom_validation_error_function (default = None)
-An optional Python function that calculates validation error if ***validation_tuning_metric*** is "custom_function". Example:
+A Python function that calculates validation error if ***validation_tuning_metric*** is "custom_function". Example:
 
 ```
 def custom_validation_error_function(y, predictions, sample_weight, group):
     squared_errors = (y-predictions)**2
     return squared_errors.mean()
 ```
 
+#### calculate_custom_loss_function (default = None)
+A Python function that calculates the loss if ***loss_function*** is "custom_function". Example:
+
+```
+def custom_loss_function(y, predictions, sample_weight, group):
+    squared_errors = (y-predictions)**2
+    return squared_errors.mean()
+```
+
+#### calculate_custom_negative_gradient_function (default = None)
+A Python function that calculates the negative gradient if ***loss_function*** is "custom_function". The negative gradient should be proportional to the negative of the first-order derivative of the custom loss function (***calculate_custom_loss_function***) with respect to the predictions. Example:
+
+```
+def custom_negative_gradient_function(y, predictions, group):
+    residuals = y-predictions
+    return residuals
+```
+
 ## Method: fit(X:npt.ArrayLike, y:npt.ArrayLike, sample_weight:npt.ArrayLike = np.empty(0), X_names:List[str]=[], validation_set_indexes:List[int]=[], prioritized_predictors_indexes:List[int]=[], monotonic_constraints:List[int]=[], group:npt.ArrayLike = np.empty(0), interaction_constraints:List[int]=[])
 
 ***This method fits the model to data.***
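Taken together, the documentation changes above define the full custom-loss workflow. Below is a minimal end-to-end sketch (an editorial addition, not part of the commit) showing how the two callables could be wired into the constructor, assuming an MSE-like loss whose negative gradient is proportional to the residuals:

```
import numpy as np
from aplr import APLRRegressor

# Custom loss: mean squared error (sample_weight and group are ignored here).
def custom_loss_function(y, predictions, sample_weight, group):
    return ((y - predictions) ** 2).mean()

# For this loss the negative gradient is proportional to the residuals.
def custom_negative_gradient_function(y, predictions, group):
    return y - predictions

model = APLRRegressor(
    loss_function="custom_function",
    calculate_custom_loss_function=custom_loss_function,
    calculate_custom_negative_gradient_function=custom_negative_gradient_function,
)

# Toy data: a noisy linear signal.
X = np.random.rand(100, 3)
y = X @ np.array([1.0, -2.0, 0.5]) + np.random.randn(100) * 0.1
model.fit(X, y)
predictions = model.predict(X)
```

With these particular definitions the model should behave much like loss_function="mse", which makes this pair a convenient sanity check before switching to a genuinely custom loss.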

aplr/aplr.py

Lines changed: 10 additions & 2 deletions
@@ -9,7 +9,9 @@ def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, loss_function:st
                  validation_ratio:float=0.2, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000,
                  min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0,
                  dispersion_parameter:float=1.5, validation_tuning_metric:str="default", quantile:float=0.5,
-                 calculate_custom_validation_error_function:Optional[Callable[[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike, npt.ArrayLike], float]]=None):
+                 calculate_custom_validation_error_function:Optional[Callable[[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike, npt.ArrayLike], float]]=None,
+                 calculate_custom_loss_function:Optional[Callable[[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike, npt.ArrayLike], float]]=None,
+                 calculate_custom_negative_gradient_function:Optional[Callable[[npt.ArrayLike, npt.ArrayLike, npt.ArrayLike], npt.ArrayLike]]=None):
         self.m=m
         self.v=v
         self.random_state=random_state
@@ -28,6 +30,8 @@ def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, loss_function:st
         self.validation_tuning_metric=validation_tuning_metric
         self.quantile=quantile
         self.calculate_custom_validation_error_function=calculate_custom_validation_error_function
+        self.calculate_custom_loss_function=calculate_custom_loss_function
+        self.calculate_custom_negative_gradient_function=calculate_custom_negative_gradient_function
 
         #Creating aplr_cpp and setting parameters
         self.APLRRegressor=aplr_cpp.APLRRegressor()
@@ -53,6 +57,8 @@ def __set_params_cpp(self):
         self.APLRRegressor.validation_tuning_metric=self.validation_tuning_metric
         self.APLRRegressor.quantile=self.quantile
         self.APLRRegressor.calculate_custom_validation_error_function=self.calculate_custom_validation_error_function
+        self.APLRRegressor.calculate_custom_loss_function=self.calculate_custom_loss_function
+        self.APLRRegressor.calculate_custom_negative_gradient_function=self.calculate_custom_negative_gradient_function
 
     def fit(self, X:npt.ArrayLike, y:npt.ArrayLike, sample_weight:npt.ArrayLike = np.empty(0), X_names:List[str]=[], validation_set_indexes:List[int]=[], prioritized_predictors_indexes:List[int]=[], monotonic_constraints:List[int]=[], group:npt.ArrayLike = np.empty(0), interaction_constraints:List[int]=[]):
         self.__set_params_cpp()
@@ -123,7 +129,9 @@ def get_params(self, deep=True):
             "dispersion_parameter":self.dispersion_parameter,
             "validation_tuning_metric":self.validation_tuning_metric,
             "quantile":self.quantile,
-            "calculate_custom_validation_error_function":self.calculate_custom_validation_error_function
+            "calculate_custom_validation_error_function":self.calculate_custom_validation_error_function,
+            "calculate_custom_loss_function":self.calculate_custom_loss_function,
+            "calculate_custom_negative_gradient_function":self.calculate_custom_negative_gradient_function
         }
 
     #For sklearn
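Because the two new callables are stored as plain attributes and returned by get_params alongside the other constructor parameters, scikit-learn utilities that rebuild estimators from their parameters should carry them along. A small sketch of the expected round-trip (an illustration, not part of the commit; assumes scikit-learn is installed):

```
from sklearn.base import clone
from aplr import APLRRegressor

def my_loss(y, predictions, sample_weight, group):
    return ((y - predictions) ** 2).mean()

def my_negative_gradient(y, predictions, group):
    return y - predictions

model = APLRRegressor(
    loss_function="custom_function",
    calculate_custom_loss_function=my_loss,
    calculate_custom_negative_gradient_function=my_negative_gradient,
)

# clone() reconstructs the estimator from get_params(), so the
# callables should survive the round-trip as the same objects.
model_copy = clone(model)
assert model_copy.calculate_custom_loss_function is my_loss
```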

cpp/APLRRegressor.h

Lines changed: 47 additions & 10 deletions
@@ -104,7 +104,7 @@ class APLRRegressor
     void throw_error_if_vector_contains_negative_values(const VectorXd &y, const std::string &error_message);
     void throw_error_if_response_is_not_greater_than_zero(const VectorXd &y, const std::string &error_message);
     void throw_error_if_dispersion_parameter_is_invalid();
-    VectorXd differentiate_predictions();
+    VectorXd differentiate_predictions_wrt_linear_predictor();
     void scale_training_observations_if_using_log_link_function();
     void revert_scaling_if_using_log_link_function();
     void cap_predictions_to_minmax_in_training(VectorXd &predictions);
@@ -144,13 +144,17 @@ class APLRRegressor
     std::string validation_tuning_metric;
     double quantile;
     std::function<double(const VectorXd &y, const VectorXd &predictions, const VectorXd &sample_weight, const VectorXi &group)> calculate_custom_validation_error_function;
+    std::function<double(const VectorXd &y, const VectorXd &predictions, const VectorXd &sample_weight, const VectorXi &group)> calculate_custom_loss_function;
+    std::function<VectorXd(const VectorXd &y, const VectorXd &predictions, const VectorXi &group)> calculate_custom_negative_gradient_function;
 
     APLRRegressor(size_t m=1000,double v=0.1,uint_fast32_t random_state=std::numeric_limits<uint_fast32_t>::lowest(),std::string loss_function="mse",
         std::string link_function="identity", size_t n_jobs=0, double validation_ratio=0.2,
         size_t reserved_terms_times_num_x=100, size_t bins=300,size_t verbosity=0,size_t max_interaction_level=1,size_t max_interactions=100000,
         size_t min_observations_in_split=20, size_t ineligible_boosting_steps_added=10, size_t max_eligible_terms=5,double dispersion_parameter=1.5,
         std::string validation_tuning_metric="default", double quantile=0.5,
-        const std::function<double(VectorXd,VectorXd,VectorXd,VectorXi)> &calculate_custom_validation_error_function={});
+        const std::function<double(VectorXd,VectorXd,VectorXd,VectorXi)> &calculate_custom_validation_error_function={},
+        const std::function<double(VectorXd,VectorXd,VectorXd,VectorXi)> &calculate_custom_loss_function={},
+        const std::function<VectorXd(VectorXd,VectorXd,VectorXi)> &calculate_custom_negative_gradient_function={});
     APLRRegressor(const APLRRegressor &other);
     ~APLRRegressor();
     void fit(const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight=VectorXd(0),const std::vector<std::string> &X_names={},
@@ -177,15 +181,18 @@ APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::s
     double validation_ratio,size_t reserved_terms_times_num_x,size_t bins,size_t verbosity,size_t max_interaction_level,
     size_t max_interactions,size_t min_observations_in_split,size_t ineligible_boosting_steps_added,size_t max_eligible_terms,double dispersion_parameter,
     std::string validation_tuning_metric, double quantile,
-    const std::function<double(VectorXd,VectorXd,VectorXd,VectorXi)> &calculate_custom_validation_error_function):
+    const std::function<double(VectorXd,VectorXd,VectorXd,VectorXi)> &calculate_custom_validation_error_function,
+    const std::function<double(VectorXd,VectorXd,VectorXd,VectorXi)> &calculate_custom_loss_function,
+    const std::function<VectorXd(VectorXd,VectorXd,VectorXi)> &calculate_custom_negative_gradient_function):
     reserved_terms_times_num_x{reserved_terms_times_num_x},intercept{NAN_DOUBLE},m{m},v{v},
     loss_function{loss_function},link_function{link_function},validation_ratio{validation_ratio},n_jobs{n_jobs},random_state{random_state},
    bins{bins},verbosity{verbosity},max_interaction_level{max_interaction_level},intercept_steps{VectorXd(0)},
     max_interactions{max_interactions},interactions_eligible{0},validation_error_steps{VectorXd(0)},
     min_observations_in_split{min_observations_in_split},ineligible_boosting_steps_added{ineligible_boosting_steps_added},
     max_eligible_terms{max_eligible_terms},number_of_base_terms{0},dispersion_parameter{dispersion_parameter},min_training_prediction_or_response{NAN_DOUBLE},
     max_training_prediction_or_response{NAN_DOUBLE}, validation_tuning_metric{validation_tuning_metric},
-    validation_indexes{std::vector<size_t>(0)}, quantile{quantile}, calculate_custom_validation_error_function{calculate_custom_validation_error_function}
+    validation_indexes{std::vector<size_t>(0)}, quantile{quantile}, calculate_custom_validation_error_function{calculate_custom_validation_error_function},
+    calculate_custom_loss_function{calculate_custom_loss_function},calculate_custom_negative_gradient_function{calculate_custom_negative_gradient_function}
 {
 }
 
@@ -201,7 +208,8 @@ APLRRegressor::APLRRegressor(const APLRRegressor &other):
     feature_importance{other.feature_importance},dispersion_parameter{other.dispersion_parameter},min_training_prediction_or_response{other.min_training_prediction_or_response},
     max_training_prediction_or_response{other.max_training_prediction_or_response},validation_tuning_metric{other.validation_tuning_metric},
     validation_indexes{other.validation_indexes}, quantile{other.quantile}, m_optimal{other.m_optimal},
-    calculate_custom_validation_error_function{other.calculate_custom_validation_error_function}
+    calculate_custom_validation_error_function{other.calculate_custom_validation_error_function},
+    calculate_custom_loss_function{other.calculate_custom_loss_function},calculate_custom_negative_gradient_function{other.calculate_custom_negative_gradient_function}
 {
 }
 
@@ -258,6 +266,8 @@ void APLRRegressor::throw_error_if_loss_function_does_not_exist()
         loss_function_exists=true;
     else if(loss_function=="weibull")
         loss_function_exists=true;
+    else if(loss_function=="custom_function")
+        loss_function_exists=true;
     if(!loss_function_exists)
         throw std::runtime_error("Loss function "+loss_function+" is not available in APLR.");
 }
@@ -691,16 +701,28 @@ VectorXd APLRRegressor::calculate_neg_gradient_current(const VectorXd &sample_we
     }
     else if(loss_function=="weibull")
     {
-        output= dispersion_parameter / predictions_current.array() * ( (y_train.array()/predictions_current.array()).pow(dispersion_parameter) - 1);
+        output=dispersion_parameter / predictions_current.array() * ( (y_train.array()/predictions_current.array()).pow(dispersion_parameter) - 1);
     }
+    else if(loss_function=="custom_function")
+    {
+        try
+        {
+            output=calculate_custom_negative_gradient_function(y_train, predictions_current, group_train);
+        }
+        catch(const std::exception& e)
+        {
+            std::string error_msg{"Error when calculating custom negative gradient function: " + static_cast<std::string>(e.what())};
+            throw std::runtime_error(error_msg);
+        }
+    }
 
     if(link_function!="identity")
-        output=output.array()*differentiate_predictions().array();
+        output=output.array()*differentiate_predictions_wrt_linear_predictor().array();
 
     return output;
 }
 
-VectorXd APLRRegressor::differentiate_predictions()
+VectorXd APLRRegressor::differentiate_predictions_wrt_linear_predictor()
 {
     if(link_function=="logit")
         return 1.0/4.0 * (linear_predictor_current.array()/2.0).cosh().array().pow(-2);
@@ -1145,7 +1167,22 @@ void APLRRegressor::calculate_and_validate_validation_error(size_t boosting_step
 void APLRRegressor::calculate_validation_error(size_t boosting_step, const VectorXd &predictions)
 {
     if(validation_tuning_metric=="default")
-        validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions,sample_weight_validation,loss_function,dispersion_parameter,group_validation,unique_groups_validation,quantile),sample_weight_validation);
+    {
+        if(loss_function=="custom_function")
+        {
+            try
+            {
+                validation_error_steps[boosting_step] = calculate_custom_loss_function(y_validation, predictions, sample_weight_validation, group_validation);
+            }
+            catch(const std::exception& e)
+            {
+                std::string error_msg{"Error when calculating custom loss function: " + static_cast<std::string>(e.what())};
+                throw std::runtime_error(error_msg);
+            }
+        }
+        else
+            validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions,sample_weight_validation,loss_function,dispersion_parameter,group_validation,unique_groups_validation,quantile),sample_weight_validation);
+    }
     else if(validation_tuning_metric=="mse")
         validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions,sample_weight_validation,MSE_LOSS_FUNCTION),sample_weight_validation);
     else if(validation_tuning_metric=="mae")
@@ -1169,7 +1206,7 @@ void APLRRegressor::calculate_validation_error(size_t boosting_step, const Vecto
     }
     catch(const std::exception& e)
     {
-        std::string error_msg{"Error when calculating custom validation error: " + static_cast<std::string>(e.what())};
+        std::string error_msg{"Error when calculating custom validation error function: " + static_cast<std::string>(e.what())};
         throw std::runtime_error(error_msg);
     }
 }
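The API reference above requires the custom negative gradient to be proportional to the negative first-order derivative of the custom loss with respect to the predictions. A finite-difference check (a standalone sketch, not part of the commit; it uses simplified two-argument callables rather than the full APLR signatures) can catch mismatched loss/gradient pairs before training:

```
import numpy as np

def check_negative_gradient(loss_fn, neg_grad_fn, y, predictions, eps=1e-6):
    # Central finite difference of the loss with respect to each prediction.
    fd = np.empty(len(predictions))
    for i in range(len(predictions)):
        up = predictions.copy()
        up[i] += eps
        down = predictions.copy()
        down[i] -= eps
        fd[i] = (loss_fn(y, up) - loss_fn(y, down)) / (2 * eps)
    # If neg_grad_fn is proportional to -dL/dp, these ratios are constant.
    return neg_grad_fn(y, predictions) / (-fd)

# The MSE-style pair from the API reference above.
loss = lambda y, p: ((y - p) ** 2).mean()
neg_grad = lambda y, p: y - p

rng = np.random.default_rng(0)
y = rng.normal(size=5)
p = rng.normal(size=5)
print(check_negative_gradient(loss, neg_grad, y, p))  # roughly constant (n/2 = 2.5)
```

For this pair the derivative of the mean loss is -2(y - p)/n, so the ratio settles at n/2, confirming the proportionality that the documentation asks for.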
