17
17
TransformerMixin ,
18
18
clone ,
19
19
is_classifier ,
20
+ is_regressor ,
20
21
)
21
22
from sklearn .linear_model import LogisticRegression
22
23
from sklearn .metrics import check_scoring
23
24
from sklearn .model_selection import KFold , StratifiedKFold , check_cv
24
- from sklearn .utils import check_array , check_X_y , indexable
25
+ from sklearn .utils import indexable
25
26
from sklearn .utils .validation import check_is_fitted
26
27
27
28
from ..parallel import parallel_func
28
- from ..utils import _check_option , _pl , _validate_type , logger , pinv , verbose , warn
29
+ from ..utils import (
30
+ _check_option ,
31
+ _pl ,
32
+ _validate_type ,
33
+ logger ,
34
+ pinv ,
35
+ verbose ,
36
+ warn ,
37
+ )
38
+ from ._fixes import validate_data
29
39
from ._ged import (
30
40
_handle_restr_mat ,
31
41
_is_cov_pos_semidef ,
@@ -340,7 +350,8 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
340
350
model : object | None
341
351
A linear model from scikit-learn with a fit method
342
352
that updates a ``coef_`` attribute.
343
- If None the model will be LogisticRegression.
353
+ If None the model will be
354
+ :class:`sklearn.linear_model.LogisticRegression`.
344
355
345
356
Attributes
346
357
----------
@@ -364,46 +375,66 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
364
375
.. footbibliography::
365
376
"""
366
377
367
- # TODO: Properly refactor this using
368
- # https://github.com/scikit-learn/scikit-learn/issues/30237#issuecomment-2465572885
369
378
_model_attr_wrap = (
370
379
"transform" ,
380
+ "fit_transform" ,
371
381
"predict" ,
372
382
"predict_proba" ,
373
- "_estimator_type " ,
374
- "__tags__" ,
383
+ "predict_log_proba " ,
384
+ "_estimator_type" , # remove after sklearn 1.6
375
385
"decision_function" ,
376
386
"score" ,
377
387
"classes_" ,
378
388
)
379
389
380
390
def __init__ (self , model = None ):
381
- # TODO: We need to set this to get our tag checking to work properly
382
- if model is None :
383
- model = LogisticRegression (solver = "liblinear" )
384
391
self .model = model
385
392
386
393
def __sklearn_tags__ (self ):
387
394
"""Get sklearn tags."""
388
- from sklearn .utils import get_tags # added in 1.6
389
-
390
- # fit method below does not allow sparse data via check_data, we could
391
- # eventually make it smarter if we had to
392
- tags = get_tags (self .model )
393
- tags .input_tags .sparse = False
395
+ tags = super ().__sklearn_tags__ ()
396
+ model = self .model if self .model is not None else LogisticRegression ()
397
+ model_tags = model .__sklearn_tags__ ()
398
+ tags .estimator_type = model_tags .estimator_type
399
+ if tags .estimator_type is not None :
400
+ model_type_tags = getattr (model_tags , f"{ tags .estimator_type } _tags" )
401
+ setattr (tags , f"{ tags .estimator_type } _tags" , model_type_tags )
394
402
return tags
395
403
396
404
def __getattr__ (self , attr ):
397
405
"""Wrap to model for some attributes."""
398
406
if attr in LinearModel ._model_attr_wrap :
399
- return getattr (self .model , attr )
400
- elif attr == "fit_transform" and hasattr (self .model , "fit_transform" ):
401
- return super ().__getattr__ (self , "_fit_transform" )
402
- return super ().__getattr__ (self , attr )
407
+ model = self .model_ if "model_" in self .__dict__ else self .model
408
+ if attr == "fit_transform" and hasattr (model , "fit_transform" ):
409
+ return self ._fit_transform
410
+ else :
411
+ return getattr (model , attr )
412
+ else :
413
+ raise AttributeError (
414
+ f"'{ type (self ).__name__ } ' object has no attribute '{ attr } '"
415
+ )
403
416
404
417
def _fit_transform (self , X , y ):
405
418
return self .fit (X , y ).transform (X )
406
419
420
+ def _validate_params (self , X ):
421
+ if self .model is not None :
422
+ model = self .model
423
+ if isinstance (model , MetaEstimatorMixin ):
424
+ model = model .estimator
425
+ is_predictor = is_regressor (model ) or is_classifier (model )
426
+ if not is_predictor :
427
+ raise ValueError (
428
+ "Linear model should be a supervised predictor "
429
+ "(classifier or regressor)"
430
+ )
431
+
432
+ # For sklearn < 1.6
433
+ try :
434
+ self ._check_n_features (X , reset = True )
435
+ except AttributeError :
436
+ pass
437
+
407
438
def fit (self , X , y , ** fit_params ):
408
439
"""Estimate the coefficients of the linear model.
409
440
@@ -424,25 +455,18 @@ def fit(self, X, y, **fit_params):
424
455
self : instance of LinearModel
425
456
Returns the modified instance.
426
457
"""
427
- if y is not None :
428
- X = check_array (X )
429
- else :
430
- X , y = check_X_y (X , y )
431
- self .n_features_in_ = X .shape [1 ]
432
- if y is not None :
433
- y = check_array (y , dtype = None , ensure_2d = False , input_name = "y" )
434
- if y .ndim > 2 :
435
- raise ValueError (
436
- f"LinearModel only accepts up to 2-dimensional y, got { y .shape } "
437
- "instead."
438
- )
458
+ self ._validate_params (X )
459
+ X , y = validate_data (self , X , y , multi_output = True )
439
460
440
461
# fit the Model
441
- self .model .fit (X , y , ** fit_params )
442
- self .model_ = self .model # for better sklearn compat
462
+ self .model_ = (
463
+ clone (self .model )
464
+ if self .model is not None
465
+ else LogisticRegression (solver = "liblinear" )
466
+ )
467
+ self .model_ .fit (X , y , ** fit_params )
443
468
444
469
# Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y
445
-
446
470
inv_Y = 1.0
447
471
X = X - X .mean (0 , keepdims = True )
448
472
if y .ndim == 2 and y .shape [1 ] != 1 :
@@ -454,12 +478,17 @@ def fit(self, X, y, **fit_params):
454
478
455
479
@property
456
480
def filters_ (self ):
457
- if hasattr (self .model , "coef_" ):
481
+ if hasattr (self .model_ , "coef_" ):
458
482
# Standard Linear Model
459
- filters = self .model .coef_
460
- elif hasattr (self .model .best_estimator_ , "coef_" ):
483
+ filters = self .model_ .coef_
484
+ elif hasattr (self .model_ , "estimators_" ):
485
+ # Linear model with OneVsRestClassifier
486
+ filters = np .vstack ([est .coef_ for est in self .model_ .estimators_ ])
487
+ elif hasattr (self .model_ , "best_estimator_" ) and hasattr (
488
+ self .model_ .best_estimator_ , "coef_"
489
+ ):
461
490
# Linear Model with GridSearchCV
462
- filters = self .model .best_estimator_ .coef_
491
+ filters = self .model_ .best_estimator_ .coef_
463
492
else :
464
493
raise ValueError ("model does not have a `coef_` attribute." )
465
494
if filters .ndim == 2 and filters .shape [0 ] == 1 :
0 commit comments