22
33import copy
44import inspect
5+ import logging
56import math
67import pickle
78from collections .abc import Iterable
1415from polars .series .series import ArrayLike
1516from tqdm .auto import tqdm
1617
18+ from ._corr import correlation_matrix
1719from .metrics import roc_auc
1820
21+ logger = logging .getLogger (__name__ )
22+
1923
2024class Estimator (Protocol ):
2125 def fit (self , X , y , ** kwargs ): ...
@@ -179,7 +183,7 @@ def __init__(
179183 n_features_to_select : float = 1 ,
180184 step : float = 1 ,
181185 importance : Callable [[RFEState ], Iterable [float ]] = _rfe_get_feature_importance ,
182- callbacks : Optional [Iterable [Callable [[RFEState ]]]] = None ,
186+ callbacks : Optional [Iterable [Callable [[RFEState ], Any ]]] = None ,
183187 quiet : bool = False ,
184188 ):
185189 self .unfit_estimator = estimator
@@ -235,14 +239,14 @@ def fit(
235239 ** fit_kwargs ,
236240 )
237241
238- state = {
239- " estimator" : est ,
240- "X" : X_loop ,
241- "y" : y ,
242- " eval_set" : fit_kwargs .get ("eval_set" , None ),
243- " features" : features ,
244- " iteration" : iteration ,
245- }
242+ state = RFEState (
243+ estimator = est ,
244+ X = X_loop ,
245+ y = y ,
246+ eval_set = fit_kwargs .get ("eval_set" , None ),
247+ features = features ,
248+ iteration = iteration ,
249+ )
246250
247251 for callback in self .callbacks :
248252 callback (state )
@@ -261,6 +265,9 @@ def fit(
261265 real_step = _get_step (len_features , step )
262266 k = len_features - real_step
263267
268+ if k <= 0 :
269+ break
270+
264271 remaining_features = (
265272 pl .LazyFrame (
266273 {"importance" : self .importance (state ), "feature" : features }
@@ -277,38 +284,15 @@ def fit(
277284 pbar .update (1 )
278285
279286 self .estimator_ = est
280- self .selected_features_ = features
287+ self .selected_features_ = sorted ( features )
281288
282289 return self
283290
284- def transform (
285- self ,
286- X : Optional [nwt .IntoDataFrame ] = None ,
287- y : Optional [Any ] = None ,
288- ** fit_kwargs ,
289- ) -> Any :
290- if X is None or y is None :
291- return self .estimator_
292-
293- if "eval_set" in fit_kwargs :
294- fit_kwargs ["eval_set" ] = [
295- (
296- nw .from_native (X_val ).select (self .selected_features_ ).to_native (),
297- y_val ,
298- )
299- for X_val , y_val in fit_kwargs ["eval_set" ]
300- ]
301-
302- return self .unfit_estimator .fit (
303- nw .from_native (X , eager_only = True )
304- .select (self .selected_features_ )
305- .to_native (),
306- y ,
307- ** fit_kwargs ,
308- )
291+ def transform (self , X : nwt .IntoFrameT ) -> nwt .IntoFrameT :
292+ return nw .from_native (X ).select (self .selected_features_ ).to_native ()
309293
310294 def fit_transform (self , X , y , ** fit_kwargs ) -> Any :
311- return self .fit (X , y , ** fit_kwargs ).transform ()
295+ return self .fit (X , y , ** fit_kwargs ).transform (X )
312296
313297
314298class NFEState (TypedDict ):
@@ -328,7 +312,7 @@ def __init__(
328312 self ,
329313 estimator : Estimator ,
330314 importance : Callable [[NFEState ], ArrayLike ] = _nfe_get_feature_importance ,
331- seed : Optional [int ] = None ,
315+ seed : Optional [int ] = 208 ,
332316 ):
333317 self .unfit_estimator = estimator
334318 self .importance = importance
@@ -347,7 +331,6 @@ def _add_noise(self, df: nw.DataFrame) -> nw.DataFrame:
347331 )
348332
349333 def fit (self , X : nwt .IntoDataFrame , y : Any , ** fit_kwargs ):
350-
351334 X_nw = nw .from_native (X , eager_only = True ).pipe (self ._add_noise )
352335
353336 if "eval_set" in fit_kwargs :
@@ -364,7 +347,7 @@ def fit(self, X: nwt.IntoDataFrame, y: Any, **fit_kwargs):
364347 X_train = X_nw .to_native ()
365348 est = self .unfit_estimator .fit (X_train , y , ** fit_kwargs )
366349
367- state = { " estimator" : est , "X" : X_train , "y" : y }
350+ state = NFEState ( estimator = est , X = X_train , y = y )
368351
369352 nfe_features = (
370353 pl .LazyFrame (
@@ -377,35 +360,118 @@ def fit(self, X: nwt.IntoDataFrame, y: Any, **fit_kwargs):
377360 )
378361 )
379362 .collect ()["feature" ]
363+ .sort ()
380364 .to_list ()
381365 )
382366
383367 self .selected_features_ = nfe_features
384368
385369 return self
386370
387- def transform (
388- self ,
389- X : nwt .IntoDataFrame ,
390- y : Any ,
391- ** fit_kwargs ,
392- ) -> Any :
393- if "eval_set" in fit_kwargs :
394- fit_kwargs ["eval_set" ] = [
395- (
396- nw .from_native (X_val ).select (self .selected_features_ ).to_native (),
397- y_val ,
398- )
399- for X_val , y_val in fit_kwargs ["eval_set" ]
400- ]
371+ def transform (self , X : nwt .IntoFrameT ) -> nwt .IntoFrameT :
372+ return nw .from_native (X ).select (self .selected_features_ ).to_native ()
373+
374+ def fit_transform (
375+ self , X : nwt .IntoDataFrameT , y : Any , ** fit_kwargs
376+ ) -> nwt .IntoDataFrameT :
377+ return self .fit (X , y , ** fit_kwargs ).transform (X )
378+
379+
380+ class CFE :
381+ def __init__ (self , threshold : float = 0.99 , seed : Optional [int ] = 208 ):
382+ self .threshold = threshold
383+ self .seed = seed
384+
385+ @staticmethod
386+ def _find_drop (corr_mat : nw .DataFrame , seed : Optional [int ]) -> tuple [str , int ]:
387+ f1_counts = corr_mat .group_by ("f1" ).agg (nw .len ().alias ("count_f1" ))
388+ f2_counts = corr_mat .group_by ("f2" ).agg (nw .len ().alias ("count_f2" ))
389+
390+ counts = (
391+ f1_counts .join (f2_counts , left_on = "f1" , right_on = "f2" , how = "full" )
392+ .with_columns (
393+ nw .coalesce ("f1" , "f2" ).alias ("feature" ),
394+ nw .sum_horizontal ("count_f1" , "count_f2" ).alias ("count" ),
395+ )
396+ .select ("feature" , "count" )
397+ .filter (nw .col ("count" ).__eq__ (nw .col ("count" ).max ()))
398+ # We need to sort by "feature" because the order after the join is not
399+ # always the same, making multiple runs even with the same seed not
400+ # reproducible without the sort.
401+ .sort ("feature" )
402+ # We could take the first or last, but let's sample so that we don't
403+ # introduce bias based on the alphabetical order.
404+ .sample (1 , seed = seed )
405+ )
406+
407+ return (counts ["feature" ].item (), counts ["count" ].item ())
408+
409+ def fit_from_correlation_matrix (
410+ self , corr_mat : nwt .IntoFrame , index : str = "" , transform : bool = True
411+ ):
412+ cm_nw = nw .from_native (corr_mat ).lazy ()
401413
402- return self .unfit_estimator .fit (
403- nw .from_native (X , eager_only = True )
404- .select (self .selected_features_ )
405- .to_native (),
406- y ,
407- ** fit_kwargs ,
414+ if transform :
415+ cm_nw = cm_nw .unpivot (index = index ).rename (
416+ {index : "f1" , "variable" : "f2" , "value" : "correlation" }
417+ )
418+
419+ features = (
420+ nw .concat (
421+ [
422+ cm_nw .select ("f1" ).rename ({"f1" : "x" }),
423+ cm_nw .select ("f2" ).rename ({"f2" : "x" }),
424+ ],
425+ how = "vertical" ,
426+ )
427+ .unique ()
428+ .collect ()["x" ]
429+ .to_list ()
430+ )
431+
432+ cm_nw = (
433+ cm_nw .with_columns (nw .col ("correlation" ).abs ())
434+ .filter (
435+ nw .col ("f1" ).__ne__ (nw .col ("f2" )),
436+ nw .col ("correlation" ).is_null ().__invert__ (),
437+ nw .col ("correlation" ).is_nan ().__invert__ (),
438+ nw .col ("correlation" ).__ge__ (self .threshold ),
439+ )
440+ .collect ()
408441 )
409442
410- def fit_transform (self , X : nwt .IntoDataFrame , y : Any , ** fit_kwargs ) -> Any :
411- return self .fit (X , y , ** fit_kwargs ).transform (X , y , ** fit_kwargs )
443+ drop_list = []
444+ i = 0
445+ while cm_nw .shape [0 ] > 0 :
446+ to_drop , count = self ._find_drop (cm_nw , self .seed )
447+
448+ logger .info (
449+ f"Iteration { i } : Dropping { to_drop } , correlated with { count } other features"
450+ )
451+
452+ cm_nw = cm_nw .filter (
453+ nw .col ("f1" )
454+ .__eq__ (to_drop )
455+ .__or__ (nw .col ("f2" ).__eq__ (to_drop ))
456+ .__invert__ ()
457+ )
458+
459+ drop_list .append (to_drop )
460+ i += 1
461+
462+ self .selected_features_ = sorted (list (set (features ) - set (drop_list )))
463+
464+ return self
465+
466+ def fit (self , X : nwt .IntoFrame ):
467+ corr_mat = correlation_matrix (X )
468+
469+ self .fit_from_correlation_matrix (corr_mat )
470+
471+ return self
472+
473+ def transform (self , X : nwt .IntoFrameT ) -> nwt .IntoFrameT :
474+ return nw .from_native (X ).select (self .selected_features_ ).to_native ()
475+
476+ def fit_transform (self , X : nwt .IntoFrameT ) -> nwt .IntoFrameT :
477+ return self .fit (X ).transform (X )
0 commit comments