Skip to content

Commit c1229d9

Browse files
authored
[Enhancement] Upgrade IncrementalLinearRegression for underdetermined systems (uxlfoundation#2175)
* Update incremental_linear.py * Update incremental_linear.py * Update incremental_linear.py * formatting * add import * Update incremental_linear.py * Update deselected_tests.yaml * Update incremental_linear.py * formatting * Update incremental_linear.py
1 parent c631127 commit c1229d9

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

deselected_tests.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,6 @@ deselected_tests:
299299
- tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle]
300300
- tests/test_common.py::test_estimators[IncrementalRidge()-check_estimators_pickle(readonly_memmap=True)]
301301
# There are not enough data to run onedal backend
302-
- tests/test_common.py::test_estimators[IncrementalLinearRegression()-check_fit2d_1sample]
303302
- tests/test_common.py::test_estimators[IncrementalRidge()-check_fit2d_1sample]
304303

305304
# Deselection of LogisticRegression tests over accuracy comparisons with sample_weights

sklearnex/linear_model/incremental_linear.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sklearn.utils.validation import check_is_fitted
2525

2626
from daal4py.sklearn._n_jobs_support import control_n_jobs
27-
from daal4py.sklearn._utils import sklearn_check_version
27+
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
2828
from onedal.linear_model import (
2929
IncrementalLinearRegression as onedal_IncrementalLinearRegression,
3030
)
@@ -221,13 +221,21 @@ def _onedal_partial_fit(self, X, y, check_input=True, queue=None):
221221
self._onedal_estimator.partial_fit(X, y, queue=queue)
222222
self._need_to_finalize = True
223223

224+
if daal_check_version((2025, "P", 200)):
225+
226+
def _onedal_validate_underdetermined(self, n_samples, n_features):
227+
pass
228+
229+
else:
230+
231+
def _onedal_validate_underdetermined(self, n_samples, n_features):
232+
is_underdetermined = n_samples < n_features + int(self.fit_intercept)
233+
if is_underdetermined:
234+
raise ValueError("Not enough samples for oneDAL")
235+
224236
def _onedal_finalize_fit(self, queue=None):
225237
assert hasattr(self, "_onedal_estimator")
226-
is_underdetermined = self.n_samples_seen_ < self.n_features_in_ + int(
227-
self.fit_intercept
228-
)
229-
if is_underdetermined:
230-
raise ValueError("Not enough samples to finalize")
238+
self._onedal_validate_underdetermined(self.n_samples_seen_, self.n_features_in_)
231239
self._onedal_estimator.finalize_fit(queue=queue)
232240
self._need_to_finalize = False
233241

@@ -260,9 +268,7 @@ def _onedal_fit(self, X, y, queue=None):
260268

261269
n_samples, n_features = X.shape
262270

263-
is_underdetermined = n_samples < n_features + int(self.fit_intercept)
264-
if is_underdetermined:
265-
raise ValueError("Not enough samples to run oneDAL backend")
271+
self._onedal_validate_underdetermined(n_samples, n_features)
266272

267273
if self.batch_size is None:
268274
self.batch_size_ = 5 * n_features

0 commit comments

Comments
 (0)