Skip to content

Commit 815b1ad

Browse files
committed
feat: use c1 and c2 consistently
1 parent 2a1c0f2 commit 815b1ad

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

python/rapidstats/selection.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -582,14 +582,17 @@ def __init__(self, threshold: float = 0.99, seed: Optional[int] = 208):
582582

583583
@staticmethod
584584
def _find_drop(corr_mat: nw.DataFrame, seed: Optional[int]) -> tuple[str, int]:
585-
f1_counts = corr_mat.group_by("f1").agg(nw.len().alias("count_f1"))
586-
f2_counts = corr_mat.group_by("f2").agg(nw.len().alias("count_f2"))
585+
c1 = "c1"
586+
c2 = "c2"
587+
588+
c1_counts = corr_mat.group_by(c1).agg(nw.len().alias("count_c1"))
589+
c2_counts = corr_mat.group_by(c2).agg(nw.len().alias("count_c2"))
587590

588591
counts = (
589-
f1_counts.join(f2_counts, left_on="f1", right_on="f2", how="full")
592+
c1_counts.join(c2_counts, left_on=c1, right_on=c2, how="full")
590593
.with_columns(
591-
nw.coalesce("f1", "f2").alias("feature"),
592-
nw.sum_horizontal("count_f1", "count_f2").alias("count"),
594+
nw.coalesce(c1, c2).alias("feature"),
595+
nw.sum_horizontal("count_c1", "count_c2").alias("count"),
593596
)
594597
.select("feature", "count")
595598
.filter(nw.col("count").__eq__(nw.col("count").max()))
@@ -619,24 +622,26 @@ def fit_from_correlation_matrix(
619622
Whether to transform the correlation matrix to long form. A wide form
620623
correlation matrix has columns that are features and an "index" that lists
621624
the features. If False, the correlation matrix must already be in long form
622-
with at least 3 columns, "f1", "f2", and "correlation" , by default True
625+
with at least 3 columns, "c1", "c2", and "correlation", by default True
623626
624627
Returns
625628
-------
626629
Self
627630
"""
631+
c1 = "c1"
632+
c2 = "c2"
628633
cm_nw = nw.from_native(corr_mat).lazy()
629634

630635
if transform:
631636
cm_nw = cm_nw.unpivot(index=index).rename(
632-
{index: "f1", "variable": "f2", "value": "correlation"}
637+
{index: c1, "variable": c2, "value": "correlation"}
633638
)
634639

635640
features = (
636641
nw.concat(
637642
[
638-
cm_nw.select("f1").rename({"f1": "x"}),
639-
cm_nw.select("f2").rename({"f2": "x"}),
643+
cm_nw.select(c1).rename({"c1": "x"}),
644+
cm_nw.select(c2).rename({"c2": "x"}),
640645
],
641646
how="vertical",
642647
)
@@ -648,7 +653,7 @@ def fit_from_correlation_matrix(
648653
cm_nw = (
649654
cm_nw.with_columns(nw.col("correlation").abs())
650655
.filter(
651-
nw.col("f1").__ne__(nw.col("f2")),
656+
nw.col(c1).__ne__(nw.col(c2)),
652657
nw.col("correlation").is_null().__invert__(),
653658
nw.col("correlation").is_nan().__invert__(),
654659
nw.col("correlation").__ge__(self.threshold),
@@ -666,9 +671,9 @@ def fit_from_correlation_matrix(
666671
)
667672

668673
cm_nw = cm_nw.filter(
669-
nw.col("f1")
674+
nw.col(c1)
670675
.__eq__(to_drop)
671-
.__or__(nw.col("f2").__eq__(to_drop))
676+
.__or__(nw.col(c2).__eq__(to_drop))
672677
.__invert__()
673678
)
674679

tests/test_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def test_cfe():
148148
assert cfe.fit_from_correlation_matrix(corr_mat).selected_features_ == expected
149149

150150
corr_mat_unpivoted = corr_mat.unpivot(index="").rename(
151-
{"": "f1", "variable": "f2", "value": "correlation"}
151+
{"": "c1", "variable": "c2", "value": "correlation"}
152152
)
153153

154154
assert (

0 commit comments

Comments
 (0)