|
3 | 3 | from sklearn.utils.validation import check_is_fitted |
4 | 4 | from sklearn.model_selection import GridSearchCV |
5 | 5 | from sklearn.metrics import make_scorer, roc_auc_score, mean_squared_error |
6 | | -from sklearn.base import RegressorMixin, ClassifierMixin, is_regressor, is_classifier |
| 6 | +from sklearn.base import RegressorMixin, ClassifierMixin, is_regressor, is_classifier, clone |
7 | 7 |
|
8 | 8 | from .mobtree import MoBTreeRegressor, MoBTreeClassifier |
9 | 9 |
|
@@ -42,11 +42,23 @@ def build_root(self): |
42 | 42 |
|
def build_leaf(self, sample_indice):
    """Fit the regression leaf model on the rows selected by ``sample_indice``.

    Three paths, cheapest first:
      * empty ``param_dict``  -> fit a clone of ``base_estimator`` as-is;
      * exactly one candidate -> set the single parameter combination and fit,
        skipping the pointless 5-fold cross-validation;
      * otherwise             -> 5-fold GridSearchCV minimizing MSE.

    Returns
    -------
    tuple
        ``(predict_func, best_estimator, best_impurity)`` where
        ``predict_func(x)`` wraps ``best_estimator.predict`` and
        ``best_impurity`` is ``self.get_loss`` on the training rows.

    NOTE(review): the previous code fit ``self.base_estimator`` in place in the
    no-CV branches, so every leaf shared (and each new fit overwrote) one
    estimator object, and ``set_params`` mutated the shared template.  We now
    ``clone`` first, matching the classifier variant of this method.
    """
    x_leaf = self.x[sample_indice]
    y_leaf = self.y[sample_indice].ravel()
    if len(self.param_dict) == 0:
        # Nothing to tune: fit a fresh copy of the template estimator.
        best_estimator = clone(self.base_estimator)
        best_estimator.fit(x_leaf, y_leaf)
    else:
        # Number of parameter combinations in the grid.
        param_size = 1
        for item in self.param_dict.values():
            param_size *= len(item)
        if param_size == 1:
            # A single combination: apply it directly, no CV needed.
            best_estimator = clone(self.base_estimator)
            best_estimator.set_params(**{key: item[0] for key, item in self.param_dict.items()})
            best_estimator.fit(x_leaf, y_leaf)
        else:
            # Full grid search; error_score=np.nan keeps a failing fold
            # from aborting the whole search.
            grid = GridSearchCV(self.base_estimator, param_grid=self.param_dict,
                                scoring={"mse": make_scorer(mean_squared_error, greater_is_better=False)},
                                cv=5, refit="mse", n_jobs=1, error_score=np.nan)
            grid.fit(x_leaf, y_leaf)
            best_estimator = grid.best_estimator_
    predict_func = lambda x: best_estimator.predict(x)
    best_impurity = self.get_loss(self.y[sample_indice], best_estimator.predict(x_leaf))
    return predict_func, best_estimator, best_impurity
@@ -80,16 +92,33 @@ def build_root(self): |
80 | 92 |
|
def build_leaf(self, sample_indice):
    """Fit the classification leaf model on the rows selected by ``sample_indice``.

    Degenerate leaves (a single class, or fewer than 5 samples of either
    class) get a constant-probability predictor, since a classifier cannot
    be fit (and AUC is undefined) there.  Otherwise, cheapest path first:
      * empty ``param_dict``  -> fit a clone of ``base_estimator`` as-is;
      * exactly one candidate -> set the single parameter combination and fit,
        skipping the pointless 5-fold cross-validation;
      * otherwise             -> 5-fold GridSearchCV maximizing AUC.

    Returns
    -------
    tuple
        ``(predict_func, best_estimator, best_impurity)``; ``predict_func``
        wraps ``decision_function`` (or returns the constant class mean for a
        degenerate leaf, with ``best_estimator`` set to ``None``), and
        ``best_impurity`` is ``self.get_loss`` on the training rows.

    NOTE(review): the previous code applied the degenerate-label guard only
    in the grid-search branch, so single-class leaves crashed the no-CV
    branches (sklearn classifiers need >= 2 classes and
    ``predict_proba(...)[:, 1]`` fails).  The guard now precedes all branches.
    """
    x_leaf = self.x[sample_indice]
    y_leaf = self.y[sample_indice]
    # Guard applies regardless of how the estimator would have been tuned.
    if (y_leaf.std() == 0) | (y_leaf.sum() < 5) | ((1 - y_leaf).sum() < 5):
        best_estimator = None
        predict_func = lambda x: np.ones(x.shape[0]) * y_leaf.mean()
        best_impurity = self.get_loss(y_leaf, predict_func(x_leaf))
        return predict_func, best_estimator, best_impurity

    if len(self.param_dict) == 0:
        # Nothing to tune: fit a fresh copy of the template estimator.
        best_estimator = clone(self.base_estimator)
        best_estimator.fit(x_leaf, y_leaf.ravel())
    else:
        # Number of parameter combinations in the grid.
        param_size = 1
        for item in self.param_dict.values():
            param_size *= len(item)
        if param_size == 1:
            # A single combination: apply it directly, no CV needed.
            best_estimator = clone(self.base_estimator)
            best_estimator.set_params(**{key: item[0] for key, item in self.param_dict.items()})
            best_estimator.fit(x_leaf, y_leaf.ravel())
        else:
            # Full grid search; error_score=np.nan keeps a failing fold
            # from aborting the whole search.
            grid = GridSearchCV(self.base_estimator, param_grid=self.param_dict,
                                scoring={"auc": make_scorer(roc_auc_score, needs_proba=True)},
                                cv=5, refit="auc", n_jobs=1, error_score=np.nan)
            grid.fit(x_leaf, y_leaf.ravel())
            best_estimator = grid.best_estimator_

    predict_func = lambda x: best_estimator.decision_function(x)
    best_impurity = self.get_loss(y_leaf, best_estimator.predict_proba(x_leaf)[:, 1])
    return predict_func, best_estimator, best_impurity
0 commit comments