@@ -582,14 +582,17 @@ def __init__(self, threshold: float = 0.99, seed: Optional[int] = 208):
582582
583583 @staticmethod
584584 def _find_drop (corr_mat : nw .DataFrame , seed : Optional [int ]) -> tuple [str , int ]:
585- f1_counts = corr_mat .group_by ("f1" ).agg (nw .len ().alias ("count_f1" ))
586- f2_counts = corr_mat .group_by ("f2" ).agg (nw .len ().alias ("count_f2" ))
585+ c1 = "c1"
586+ c2 = "c2"
587+
588+ c1_counts = corr_mat .group_by (c1 ).agg (nw .len ().alias ("count_c1" ))
589+ c2_counts = corr_mat .group_by (c2 ).agg (nw .len ().alias ("count_c2" ))
587590
588591 counts = (
589- f1_counts .join (f2_counts , left_on = "f1" , right_on = "f2" , how = "full" )
592+ c1_counts .join (c2_counts , left_on = c1 , right_on = c2 , how = "full" )
590593 .with_columns (
591- nw .coalesce ("f1" , "f2" ).alias ("feature" ),
592- nw .sum_horizontal ("count_f1 " , "count_f2 " ).alias ("count" ),
594+ nw .coalesce (c1 , c2 ).alias ("feature" ),
595+ nw .sum_horizontal ("count_c1 " , "count_c2 " ).alias ("count" ),
593596 )
594597 .select ("feature" , "count" )
595598 .filter (nw .col ("count" ).__eq__ (nw .col ("count" ).max ()))
@@ -619,24 +622,26 @@ def fit_from_correlation_matrix(
619622 Whether to transfrom the correlation matrix to long form. A wide form
620623 correlation matrix has columns that are features and an "index" that lists
621624 the features. If False, the correlation matrix must already be in long form
622- with at least 3 columns, "f1 ", "f2 ", and "correlation" , by default True
625+ with at least 3 columns, "c1 ", "c2 ", and "correlation" , by default True
623626
624627 Returns
625628 -------
626629 Self
627630 """
631+ c1 = "c1"
632+ c2 = "c2"
628633 cm_nw = nw .from_native (corr_mat ).lazy ()
629634
630635 if transform :
631636 cm_nw = cm_nw .unpivot (index = index ).rename (
632- {index : "f1" , "variable" : "f2" , "value" : "correlation" }
637+ {index : c1 , "variable" : c2 , "value" : "correlation" }
633638 )
634639
635640 features = (
636641 nw .concat (
637642 [
638- cm_nw .select ("f1" ).rename ({"f1 " : "x" }),
639- cm_nw .select ("f2" ).rename ({"f2 " : "x" }),
643+ cm_nw .select (c1 ).rename ({"c1 " : "x" }),
644+ cm_nw .select (c2 ).rename ({"c2 " : "x" }),
640645 ],
641646 how = "vertical" ,
642647 )
@@ -648,7 +653,7 @@ def fit_from_correlation_matrix(
648653 cm_nw = (
649654 cm_nw .with_columns (nw .col ("correlation" ).abs ())
650655 .filter (
651- nw .col ("f1" ).__ne__ (nw .col ("f2" )),
656+ nw .col (c1 ).__ne__ (nw .col (c2 )),
652657 nw .col ("correlation" ).is_null ().__invert__ (),
653658 nw .col ("correlation" ).is_nan ().__invert__ (),
654659 nw .col ("correlation" ).__ge__ (self .threshold ),
@@ -666,9 +671,9 @@ def fit_from_correlation_matrix(
666671 )
667672
668673 cm_nw = cm_nw .filter (
669- nw .col ("f1" )
674+ nw .col (c1 )
670675 .__eq__ (to_drop )
671- .__or__ (nw .col ("f2" ).__eq__ (to_drop ))
676+ .__or__ (nw .col (c2 ).__eq__ (to_drop ))
672677 .__invert__ ()
673678 )
674679
0 commit comments