Skip to content

Commit 941b7e3

Browse files
author
ZebinYang
committed
change glmtree classifier default predict function from pred proba to decision function; version 0.2.7
1 parent 4dfac23 commit 941b7e3

File tree

3 files changed

+17
-17
lines changed

3 files changed

+17
-17
lines changed

examples/demo.ipynb

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
"execution_count": null,
66
"metadata": {
77
"ExecuteTime": {
8-
"end_time": "2021-12-27T11:25:01.467037Z",
9-
"start_time": "2021-12-27T11:24:59.706592Z"
8+
"end_time": "2021-12-27T11:41:27.070305Z",
9+
"start_time": "2021-12-27T11:41:25.800526Z"
1010
}
1111
},
1212
"outputs": [],
@@ -32,8 +32,8 @@
3232
"execution_count": null,
3333
"metadata": {
3434
"ExecuteTime": {
35-
"end_time": "2021-12-27T11:25:01.500112Z",
36-
"start_time": "2021-12-27T11:25:01.469302Z"
35+
"end_time": "2021-12-27T11:41:27.095134Z",
36+
"start_time": "2021-12-27T11:41:27.071798Z"
3737
}
3838
},
3939
"outputs": [],
@@ -48,8 +48,8 @@
4848
"execution_count": null,
4949
"metadata": {
5050
"ExecuteTime": {
51-
"end_time": "2021-12-27T11:25:50.669194Z",
52-
"start_time": "2021-12-27T11:25:01.501343Z"
51+
"end_time": "2021-12-27T11:43:04.465481Z",
52+
"start_time": "2021-12-27T11:41:27.096046Z"
5353
},
5454
"scrolled": true
5555
},
@@ -71,8 +71,8 @@
7171
"execution_count": null,
7272
"metadata": {
7373
"ExecuteTime": {
74-
"end_time": "2021-12-27T11:25:50.670794Z",
75-
"start_time": "2021-12-27T11:25:50.670779Z"
74+
"end_time": "2021-12-27T11:43:05.991216Z",
75+
"start_time": "2021-12-27T11:43:04.467280Z"
7676
}
7777
},
7878
"outputs": [],
@@ -97,8 +97,8 @@
9797
"execution_count": null,
9898
"metadata": {
9999
"ExecuteTime": {
100-
"end_time": "2021-12-27T11:25:50.671677Z",
101-
"start_time": "2021-12-27T11:25:50.671662Z"
100+
"end_time": "2021-12-27T11:43:05.992967Z",
101+
"start_time": "2021-12-27T11:43:05.992948Z"
102102
}
103103
},
104104
"outputs": [],
@@ -113,8 +113,8 @@
113113
"execution_count": null,
114114
"metadata": {
115115
"ExecuteTime": {
116-
"end_time": "2021-12-27T11:25:50.672528Z",
117-
"start_time": "2021-12-27T11:25:50.672513Z"
116+
"end_time": "2021-12-27T11:43:05.994028Z",
117+
"start_time": "2021-12-27T11:43:05.994010Z"
118118
}
119119
},
120120
"outputs": [],
@@ -135,8 +135,8 @@
135135
"execution_count": null,
136136
"metadata": {
137137
"ExecuteTime": {
138-
"end_time": "2021-12-27T11:25:50.673375Z",
139-
"start_time": "2021-12-27T11:25:50.673360Z"
138+
"end_time": "2021-12-27T11:43:05.995089Z",
139+
"start_time": "2021-12-27T11:43:05.995071Z"
140140
}
141141
},
142142
"outputs": [],

simtree/glmtree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def build_leaf(self, sample_indice):
4343
nx = (self.x[sample_indice] - mx) / sx
4444

4545
if len(self.reg_lambda) > 1:
46-
best_estimator = LassoCV(alphas=self.reg_lambda, cv=5, n_jobs=self.n_jobs, precompute=False, random_state=self.random_state)
46+
best_estimator = LassoCV(alphas=self.reg_lambda, cv=5, precompute=False, random_state=self.random_state)
4747
best_estimator.fit(nx, self.y[sample_indice], self.sample_weight[sample_indice])
4848
else:
4949
if self.reg_lambda[0] > 0:
@@ -93,7 +93,7 @@ def build_leaf(self, sample_indice):
9393
else:
9494
if len(self.reg_lambda) > 1:
9595
best_estimator = LogisticRegressionCV(Cs=self.reg_lambda, penalty="l1", solver="liblinear", scoring="roc_auc",
96-
cv=5, n_jobs=self.n_jobs, random_state=self.random_state)
96+
cv=5, random_state=self.random_state)
9797
else:
9898
best_estimator = LogisticRegression(alpha=self.reg_lambda[0], precompute=False, random_state=self.random_state)
9999

simtree/mobtree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def fit(self, x, y):
256256
if self.split_features is None:
257257
self.split_features = np.arange(n_features).tolist()
258258

259-
if self.n_feature_search > len(self.split_features) and (self.max_depth >= 1):
259+
if self.n_feature_search > len(self.split_features) and (self.max_depth >= 0):
260260
self.important_split_features = self.split_features
261261
else:
262262
self.important_split_features = self.screen_features(sample_indice)

0 commit comments

Comments
 (0)