77from sklearn .utils .validation import check_is_fitted
88from sklearn .utils import check_array
99from sklearn .utils .multiclass import check_classification_targets
10- from sklearn .utils .multiclass import type_of_target
1110from sklearn .utils .validation import check_X_y
12- from sklearn .utils .validation import check_random_state
13- from sklearn .utils .validation import column_or_1d
14- from sklearn .utils .validation import check_consistent_length
1511from sklearn .utils .validation import _check_sample_weight
1612from sklearn .ensemble import GradientBoostingClassifier , GradientBoostingRegressor
17- from sklearn .datasets import load_breast_cancer
1813from sklearn .model_selection import train_test_split
1914from sklearn .metrics import accuracy_score , roc_auc_score
2015from tqdm import tqdm
@@ -34,17 +29,44 @@ def __init__(
3429 self ,
3530 n_boosting_rounds = 100 ,
3631 max_leaf_nodes = 3 ,
32+ reg_param = 0.0 ,
3733 n_boosting_rounds_marginal = 0 ,
3834 max_leaf_nodes_marginal = 2 ,
39- fit_linear_marginal = False ,
35+ reg_param_marginal = 0.0 ,
36+ fit_linear_marginal = None ,
4037 random_state = None ,
4138 ):
42- self .max_leaf_nodes = max_leaf_nodes
43- self .random_state = random_state
39+ """
40+ Params
41+ ------
42+ n_boosting_rounds : int
43+ Number of boosting rounds for the cyclic boosting.
44+ max_leaf_nodes : int
45+ Maximum number of leaf nodes for the trees in the cyclic boosting.
46+ reg_param : float
47+ Regularization parameter for the cyclic boosting.
48+ n_boosting_rounds_marginal : int
49+ Number of boosting rounds for the marginal boosting.
50+ max_leaf_nodes_marginal : int
51+ Maximum number of leaf nodes for the trees in the marginal boosting.
52+ reg_param_marginal : float
53+ Regularization parameter for the marginal boosting.
54+ fit_linear_marginal : str [None, "None", "ridge", "NNLS"]
55+ Whether to fit a linear model to the marginal effects.
56+ NNLS for non-negative least squares
57+ ridge for ridge regression
58+ None for no linear model
59+ random_state : int
60+ Random seed.
61+ """
4462 self .n_boosting_rounds = n_boosting_rounds
63+ self .max_leaf_nodes = max_leaf_nodes
64+ self .reg_param = reg_param
4565 self .max_leaf_nodes_marginal = max_leaf_nodes_marginal
66+ self .reg_param_marginal = reg_param_marginal
4667 self .n_boosting_rounds_marginal = n_boosting_rounds_marginal
4768 self .fit_linear_marginal = fit_linear_marginal
69+ self .random_state = random_state
4870
4971 def fit (self , X , y , sample_weight = None , learning_rate = 0.01 , validation_frac = 0.15 ):
5072 X , y = check_X_y (X , y , accept_sparse = False , multi_output = False )
@@ -110,11 +132,18 @@ def _marginal_fit(
110132 n_estimators = self .n_boosting_rounds_marginal ,
111133 )
112134 est .fit (X_ , residuals_train , sample_weight = sample_weight_train )
135+ if self .reg_param_marginal > 0 :
136+ est = imodels .HSTreeRegressor (est , reg_param = self .reg_param_marginal )
113137 self .estimators_marginal .append (est )
114138
115- if self .fit_linear_marginal :
116- linear_marginal = RidgeCV (fit_intercept = False )
117- # linear_marginal = LinearRegression(fit_intercept=False, positive=True)
139+ if (
140+ self .fit_linear_marginal is not None
141+ and not self .fit_linear_marginal == "None"
142+ ):
143+ if self .fit_linear_marginal .lower () == "ridge" :
144+ linear_marginal = RidgeCV (fit_intercept = False )
145+ elif self .fit_linear_marginal .lower () == "nnls" :
146+ linear_marginal = LinearRegression (fit_intercept = False , positive = True )
118147 linear_marginal .fit (
119148 np .array ([est .predict (X_train ) for est in self .estimators_marginal ]).T ,
120149 residuals_train ,
@@ -149,6 +178,8 @@ def _cyclic_boost(
149178 )
150179 if not succesfully_split_on_feature :
151180 continue
181+ if self .reg_param > 0 :
182+ est = imodels .HSTreeRegressor (est , reg_param = self .reg_param )
152183 self .estimators_ .append (est )
153184 residuals_train = residuals_train - self .learning_rate * est .predict (
154185 X_train
0 commit comments