Skip to content

Commit 1a2a5ec

Browse files
Adjusted fit recursion for multi component weights
1 parent 2e983a4 commit 1a2a5ec

File tree

1 file changed

+14
-23
lines changed

1 file changed

+14
-23
lines changed

pysindy/optimizers/base.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from sklearn.utils.extmath import safe_sparse_dot
1717
from sklearn.utils.validation import check_is_fitted
1818
from sklearn.utils.validation import check_X_y
19+
from sklearn.base import clone
1920

2021
from .._typing import Float2D
2122
from .._typing import FloatDType
@@ -175,29 +176,19 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws):
175176

176177
# The next scope is for when the sample weights
177178
# are different for each output component
178-
if sample_weight is not None:
179-
sample_weight = np.asarray(sample_weight)
180-
if sample_weight.shape == y.shape:
181-
# Fit separately per target with its own weights, then combine
182-
coefs, histories = [], []
183-
for j in range(y.shape[1]):
184-
sw_j = sample_weight[:, j]
185-
# recursive call on 1D y[:, j]
186-
sub = self.__class__(
187-
alpha=self.alpha,
188-
threshold=self.threshold,
189-
normalize_columns=self.normalize_columns,
190-
unbias=self.unbias,
191-
max_iter=self.max_iter,
192-
copy_X=self.copy_X,
193-
initial_guess=None,
194-
)
195-
sub.fit(x_, y[:, j], sample_weight=sw_j, **reduce_kws)
196-
coefs.append(sub.coef_.ravel())
197-
histories.append(sub.history_)
198-
self.coef_ = np.column_stack(coefs)
199-
self.history_ = histories
200-
return self
179+
# we select the weights of the each output component and
180+
# recursively fit
181+
if sample_weight.shape == y.shape:
182+
coefs, histories = [], []
183+
for j in range(y.shape[1]):
184+
sw_j = sample_weight[:, j]
185+
sub = clone(self)
186+
sub.fit(x_, y[:, j], sample_weight=sw_j, **reduce_kws)
187+
coefs.append(sub.coef_.ravel())
188+
histories.append(sub.history_)
189+
self.coef_ = np.column_stack(coefs)
190+
self.history_ = histories
191+
return self
201192

202193
x, y, X_offset, y_offset, X_scale = _preprocess_data(
203194
x_,

0 commit comments

Comments
 (0)