Skip to content

Commit f29235d

Browse files
committed
allow tree shrinkage in gam
1 parent b123d26 commit f29235d

File tree

3 files changed

+204
-93
lines changed

3 files changed

+204
-93
lines changed

imodels/algebraic/gam.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,9 @@
77
from sklearn.utils.validation import check_is_fitted
88
from sklearn.utils import check_array
99
from sklearn.utils.multiclass import check_classification_targets
10-
from sklearn.utils.multiclass import type_of_target
1110
from sklearn.utils.validation import check_X_y
12-
from sklearn.utils.validation import check_random_state
13-
from sklearn.utils.validation import column_or_1d
14-
from sklearn.utils.validation import check_consistent_length
1511
from sklearn.utils.validation import _check_sample_weight
1612
from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
17-
from sklearn.datasets import load_breast_cancer
1813
from sklearn.model_selection import train_test_split
1914
from sklearn.metrics import accuracy_score, roc_auc_score
2015
from tqdm import tqdm
@@ -34,17 +29,44 @@ def __init__(
3429
self,
3530
n_boosting_rounds=100,
3631
max_leaf_nodes=3,
32+
reg_param=0.0,
3733
n_boosting_rounds_marginal=0,
3834
max_leaf_nodes_marginal=2,
39-
fit_linear_marginal=False,
35+
reg_param_marginal=0.0,
36+
fit_linear_marginal=None,
4037
random_state=None,
4138
):
42-
self.max_leaf_nodes = max_leaf_nodes
43-
self.random_state = random_state
39+
"""
40+
Params
41+
------
42+
n_boosting_rounds : int
43+
Number of boosting rounds for the cyclic boosting.
44+
max_leaf_nodes : int
45+
Maximum number of leaf nodes for the trees in the cyclic boosting.
46+
reg_param : float
47+
Regularization parameter for the cyclic boosting.
48+
n_boosting_rounds_marginal : int
49+
Number of boosting rounds for the marginal boosting.
50+
max_leaf_nodes_marginal : int
51+
Maximum number of leaf nodes for the trees in the marginal boosting.
52+
reg_param_marginal : float
53+
Regularization parameter for the marginal boosting.
54+
fit_linear_marginal : str [None, "None", "ridge", "NNLS"]
55+
Whether to fit a linear model to the marginal effects.
56+
NNLS for non-negative least squares
57+
ridge for ridge regression
58+
None for no linear model
59+
random_state : int
60+
Random seed.
61+
"""
4462
self.n_boosting_rounds = n_boosting_rounds
63+
self.max_leaf_nodes = max_leaf_nodes
64+
self.reg_param = reg_param
4565
self.max_leaf_nodes_marginal = max_leaf_nodes_marginal
66+
self.reg_param_marginal = reg_param_marginal
4667
self.n_boosting_rounds_marginal = n_boosting_rounds_marginal
4768
self.fit_linear_marginal = fit_linear_marginal
69+
self.random_state = random_state
4870

4971
def fit(self, X, y, sample_weight=None, learning_rate=0.01, validation_frac=0.15):
5072
X, y = check_X_y(X, y, accept_sparse=False, multi_output=False)
@@ -110,11 +132,18 @@ def _marginal_fit(
110132
n_estimators=self.n_boosting_rounds_marginal,
111133
)
112134
est.fit(X_, residuals_train, sample_weight=sample_weight_train)
135+
if self.reg_param_marginal > 0:
136+
est = imodels.HSTreeRegressor(est, reg_param=self.reg_param_marginal)
113137
self.estimators_marginal.append(est)
114138

115-
if self.fit_linear_marginal:
116-
linear_marginal = RidgeCV(fit_intercept=False)
117-
# linear_marginal = LinearRegression(fit_intercept=False, positive=True)
139+
if (
140+
self.fit_linear_marginal is not None
141+
and not self.fit_linear_marginal == "None"
142+
):
143+
if self.fit_linear_marginal.lower() == "ridge":
144+
linear_marginal = RidgeCV(fit_intercept=False)
145+
elif self.fit_linear_marginal.lower() == "nnls":
146+
linear_marginal = LinearRegression(fit_intercept=False, positive=True)
118147
linear_marginal.fit(
119148
np.array([est.predict(X_train) for est in self.estimators_marginal]).T,
120149
residuals_train,
@@ -149,6 +178,8 @@ def _cyclic_boost(
149178
)
150179
if not succesfully_split_on_feature:
151180
continue
181+
if self.reg_param > 0:
182+
est = imodels.HSTreeRegressor(est, reg_param=self.reg_param)
152183
self.estimators_.append(est)
153184
residuals_train = residuals_train - self.learning_rate * est.predict(
154185
X_train

0 commit comments

Comments (0)