|
16 | 16 | from sklearn.utils.extmath import safe_sparse_dot |
17 | 17 | from sklearn.utils.validation import check_is_fitted |
18 | 18 | from sklearn.utils.validation import check_X_y |
| 19 | +from sklearn.base import clone |
19 | 20 |
|
20 | 21 | from .._typing import Float2D |
21 | 22 | from .._typing import FloatDType |
@@ -175,29 +176,19 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws): |
175 | 176 |
|
176 | 177 | # The next scope is for when the sample weights |
177 | 178 | # are different for each output component |
178 | | - if sample_weight is not None: |
179 | | - sample_weight = np.asarray(sample_weight) |
180 | | - if sample_weight.shape == y.shape: |
181 | | - # Fit separately per target with its own weights, then combine |
182 | | - coefs, histories = [], [] |
183 | | - for j in range(y.shape[1]): |
184 | | - sw_j = sample_weight[:, j] |
185 | | - # recursive call on 1D y[:, j] |
186 | | - sub = self.__class__( |
187 | | - alpha=self.alpha, |
188 | | - threshold=self.threshold, |
189 | | - normalize_columns=self.normalize_columns, |
190 | | - unbias=self.unbias, |
191 | | - max_iter=self.max_iter, |
192 | | - copy_X=self.copy_X, |
193 | | - initial_guess=None, |
194 | | - ) |
195 | | - sub.fit(x_, y[:, j], sample_weight=sw_j, **reduce_kws) |
196 | | - coefs.append(sub.coef_.ravel()) |
197 | | - histories.append(sub.history_) |
198 | | - self.coef_ = np.column_stack(coefs) |
199 | | - self.history_ = histories |
200 | | - return self |
| 179 | + # we select the weights of the each output component and |
| 180 | + # recursively fit |
| 181 | + if sample_weight.shape == y.shape: |
| 182 | + coefs, histories = [], [] |
| 183 | + for j in range(y.shape[1]): |
| 184 | + sw_j = sample_weight[:, j] |
| 185 | + sub = clone(self) |
| 186 | + sub.fit(x_, y[:, j], sample_weight=sw_j, **reduce_kws) |
| 187 | + coefs.append(sub.coef_.ravel()) |
| 188 | + histories.append(sub.history_) |
| 189 | + self.coef_ = np.column_stack(coefs) |
| 190 | + self.history_ = histories |
| 191 | + return self |
201 | 192 |
|
202 | 193 | x, y, X_offset, y_offset, X_scale = _preprocess_data( |
203 | 194 | x_, |
|
0 commit comments