Skip to content

Commit 97eae34

Browse files
2 parents 9084812 + 4210995 commit 97eae34

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

pysindy/_core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def __init__(
298298
self.discrete_time = discrete_time
299299
self.set_fit_request(sample_weight=True)
300300
self.set_score_request(sample_weight=True)
301+
self.optimizer.set_fit_request(sample_weight=True)
301302

302303
def fit(
303304
self,

pysindy/optimizers/base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,32 @@ def fit(self, x_, y, sample_weight=None, **reduce_kws):
173173
x_, y = drop_nan_samples(x_, y)
174174
x_, y = check_X_y(x_, y, accept_sparse=[], y_numeric=True, multi_output=True)
175175

176+
# The next scope is for when the sample weights
177+
# are different for each output component
178+
if sample_weight is not None:
179+
sample_weight = np.asarray(sample_weight)
180+
if sample_weight.shape == y.shape:
181+
# Fit separately per target with its own weights, then combine
182+
coefs, histories = [], []
183+
for j in range(y.shape[1]):
184+
sw_j = sample_weight[:, j]
185+
# recursive call on 1D y[:, j]
186+
sub = self.__class__(
187+
alpha=self.alpha,
188+
threshold=self.threshold,
189+
normalize_columns=self.normalize_columns,
190+
unbias=self.unbias,
191+
max_iter=self.max_iter,
192+
copy_X=self.copy_X,
193+
initial_guess=None,
194+
)
195+
sub.fit(x_, y[:, j], sample_weight=sw_j, **reduce_kws)
196+
coefs.append(sub.coef_.ravel())
197+
histories.append(sub.history_)
198+
self.coef_ = np.column_stack(coefs)
199+
self.history_ = histories
200+
return self
201+
176202
x, y, X_offset, y_offset, X_scale = _preprocess_data(
177203
x_,
178204
y,

0 commit comments

Comments
 (0)