Skip to content

Commit b2ba0a7

Browse files
authored
Merge pull request #16 from CangyuanLi/selection
2 parents 5319f7b + 41d19a0 commit b2ba0a7

File tree

3 files changed

+227
-63
lines changed

3 files changed

+227
-63
lines changed

python/rapidstats/_corr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import itertools
22
from typing import Literal, Optional, Union
33

4-
import narwhals as nw
5-
import narwhals.typing as nwt
4+
import narwhals.stable.v1 as nw
5+
import narwhals.stable.v1.typing as nwt
66
import polars as pl
77

88
CorrelationMethod = Literal["pearson", "spearman"]

python/rapidstats/selection.py

Lines changed: 127 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import copy
44
import inspect
5+
import logging
56
import math
67
import pickle
78
from collections.abc import Iterable
@@ -14,8 +15,11 @@
1415
from polars.series.series import ArrayLike
1516
from tqdm.auto import tqdm
1617

18+
from ._corr import correlation_matrix
1719
from .metrics import roc_auc
1820

21+
logger = logging.getLogger(__name__)
22+
1923

2024
class Estimator(Protocol):
2125
def fit(self, X, y, **kwargs): ...
@@ -179,7 +183,7 @@ def __init__(
179183
n_features_to_select: float = 1,
180184
step: float = 1,
181185
importance: Callable[[RFEState], Iterable[float]] = _rfe_get_feature_importance,
182-
callbacks: Optional[Iterable[Callable[[RFEState]]]] = None,
186+
callbacks: Optional[Iterable[Callable[[RFEState], Any]]] = None,
183187
quiet: bool = False,
184188
):
185189
self.unfit_estimator = estimator
@@ -235,14 +239,14 @@ def fit(
235239
**fit_kwargs,
236240
)
237241

238-
state = {
239-
"estimator": est,
240-
"X": X_loop,
241-
"y": y,
242-
"eval_set": fit_kwargs.get("eval_set", None),
243-
"features": features,
244-
"iteration": iteration,
245-
}
242+
state = RFEState(
243+
estimator=est,
244+
X=X_loop,
245+
y=y,
246+
eval_set=fit_kwargs.get("eval_set", None),
247+
features=features,
248+
iteration=iteration,
249+
)
246250

247251
for callback in self.callbacks:
248252
callback(state)
@@ -261,6 +265,9 @@ def fit(
261265
real_step = _get_step(len_features, step)
262266
k = len_features - real_step
263267

268+
if k <= 0:
269+
break
270+
264271
remaining_features = (
265272
pl.LazyFrame(
266273
{"importance": self.importance(state), "feature": features}
@@ -277,38 +284,15 @@ def fit(
277284
pbar.update(1)
278285

279286
self.estimator_ = est
280-
self.selected_features_ = features
287+
self.selected_features_ = sorted(features)
281288

282289
return self
283290

284-
def transform(
285-
self,
286-
X: Optional[nwt.IntoDataFrame] = None,
287-
y: Optional[Any] = None,
288-
**fit_kwargs,
289-
) -> Any:
290-
if X is None or y is None:
291-
return self.estimator_
292-
293-
if "eval_set" in fit_kwargs:
294-
fit_kwargs["eval_set"] = [
295-
(
296-
nw.from_native(X_val).select(self.selected_features_).to_native(),
297-
y_val,
298-
)
299-
for X_val, y_val in fit_kwargs["eval_set"]
300-
]
301-
302-
return self.unfit_estimator.fit(
303-
nw.from_native(X, eager_only=True)
304-
.select(self.selected_features_)
305-
.to_native(),
306-
y,
307-
**fit_kwargs,
308-
)
291+
def transform(self, X: nwt.IntoFrameT) -> nwt.IntoFrameT:
    """Return ``X`` restricted to the columns stored in ``selected_features_``.

    The frame is wrapped with narwhals for the selection and converted back to
    its native type before being returned.
    """
    frame = nw.from_native(X)
    kept = frame.select(self.selected_features_)
    return kept.to_native()
309293

310294
def fit_transform(self, X, y, **fit_kwargs) -> Any:
311-
return self.fit(X, y, **fit_kwargs).transform()
295+
return self.fit(X, y, **fit_kwargs).transform(X)
312296

313297

314298
class NFEState(TypedDict):
@@ -328,7 +312,7 @@ def __init__(
328312
self,
329313
estimator: Estimator,
330314
importance: Callable[[NFEState], ArrayLike] = _nfe_get_feature_importance,
331-
seed: Optional[int] = None,
315+
seed: Optional[int] = 208,
332316
):
333317
self.unfit_estimator = estimator
334318
self.importance = importance
@@ -347,7 +331,6 @@ def _add_noise(self, df: nw.DataFrame) -> nw.DataFrame:
347331
)
348332

349333
def fit(self, X: nwt.IntoDataFrame, y: Any, **fit_kwargs):
350-
351334
X_nw = nw.from_native(X, eager_only=True).pipe(self._add_noise)
352335

353336
if "eval_set" in fit_kwargs:
@@ -364,7 +347,7 @@ def fit(self, X: nwt.IntoDataFrame, y: Any, **fit_kwargs):
364347
X_train = X_nw.to_native()
365348
est = self.unfit_estimator.fit(X_train, y, **fit_kwargs)
366349

367-
state = {"estimator": est, "X": X_train, "y": y}
350+
state = NFEState(estimator=est, X=X_train, y=y)
368351

369352
nfe_features = (
370353
pl.LazyFrame(
@@ -377,35 +360,118 @@ def fit(self, X: nwt.IntoDataFrame, y: Any, **fit_kwargs):
377360
)
378361
)
379362
.collect()["feature"]
363+
.sort()
380364
.to_list()
381365
)
382366

383367
self.selected_features_ = nfe_features
384368

385369
return self
386370

387-
def transform(
388-
self,
389-
X: nwt.IntoDataFrame,
390-
y: Any,
391-
**fit_kwargs,
392-
) -> Any:
393-
if "eval_set" in fit_kwargs:
394-
fit_kwargs["eval_set"] = [
395-
(
396-
nw.from_native(X_val).select(self.selected_features_).to_native(),
397-
y_val,
398-
)
399-
for X_val, y_val in fit_kwargs["eval_set"]
400-
]
371+
def transform(self, X: nwt.IntoFrameT) -> nwt.IntoFrameT:
    """Return ``X`` restricted to the columns stored in ``selected_features_``.

    The frame is wrapped with narwhals for the selection and converted back to
    its native type before being returned.
    """
    frame = nw.from_native(X)
    kept = frame.select(self.selected_features_)
    return kept.to_native()
373+
374+
def fit_transform(
    self, X: nwt.IntoDataFrameT, y: Any, **fit_kwargs
) -> nwt.IntoDataFrameT:
    """Fit on ``(X, y)`` and return ``X`` reduced to the selected features.

    Extra keyword arguments are forwarded unchanged to ``fit``.
    """
    fitted = self.fit(X, y, **fit_kwargs)
    return fitted.transform(X)
378+
379+
380+
class CFE:
    """Correlation-based feature elimination.

    Repeatedly drops one feature at a time until no remaining pair of distinct
    features has an absolute correlation at or above ``threshold``. At each
    step the feature that appears in the most above-threshold rows is dropped;
    ties are broken with a seeded random sample.

    Parameters
    ----------
    threshold : float
        Absolute-correlation cutoff. Pairs with ``abs(correlation) >=
        threshold`` are considered correlated.
    seed : Optional[int]
        Seed passed to the tie-breaking sample. The same seed is reused on
        every iteration.

    Attributes
    ----------
    selected_features_ : list[str]
        Sorted names of the features that survived elimination. Set by
        ``fit`` / ``fit_from_correlation_matrix``.
    """

    def __init__(self, threshold: float = 0.99, seed: Optional[int] = 208):
        self.threshold = threshold
        self.seed = seed

    @staticmethod
    def _find_drop(corr_mat: nw.DataFrame, seed: Optional[int]) -> tuple[str, int]:
        """Pick the next feature to drop from a long-form correlation frame.

        Parameters
        ----------
        corr_mat : nw.DataFrame
            Frame with (at least) the columns ``f1`` and ``f2``, one row per
            above-threshold pair.
        seed : Optional[int]
            Seed for the tie-breaking sample.

        Returns
        -------
        tuple[str, int]
            The chosen feature and the number of rows it appears in, counting
            appearances in ``f1`` and in ``f2``. If the frame lists a pair in
            both orientations, each partner is therefore counted twice.
        """
        f1_counts = corr_mat.group_by("f1").agg(nw.len().alias("count_f1"))
        f2_counts = corr_mat.group_by("f2").agg(nw.len().alias("count_f2"))

        counts = (
            # A full join keeps features that only ever appear on one side.
            f1_counts.join(f2_counts, left_on="f1", right_on="f2", how="full")
            .with_columns(
                nw.coalesce("f1", "f2").alias("feature"),
                nw.sum_horizontal("count_f1", "count_f2").alias("count"),
            )
            .select("feature", "count")
            .filter(nw.col("count") == nw.col("count").max())
            # We need to sort by "feature" because the order after the join is
            # not always the same, making multiple runs even with the same seed
            # not reproducible without the sort.
            .sort("feature")
            # We could take the first or last, but let's sample so that we
            # don't introduce bias based on the alphabetical order.
            .sample(1, seed=seed)
        )

        return (counts["feature"].item(), counts["count"].item())

    def fit_from_correlation_matrix(
        self, corr_mat: nwt.IntoFrame, index: str = "", transform: bool = True
    ):
        """Select features from a precomputed correlation matrix.

        Parameters
        ----------
        corr_mat : nwt.IntoFrame
            Either a wide matrix (one label column named ``index`` plus one
            column per feature) or, when ``transform`` is False, a long frame
            that already has the columns ``f1``, ``f2`` and ``correlation``.
        index : str
            Name of the label column in the wide layout. Only used when
            ``transform`` is True.
        transform : bool
            Whether to unpivot ``corr_mat`` from the wide layout into the long
            ``f1`` / ``f2`` / ``correlation`` layout first.

        Returns
        -------
        CFE
            ``self``, with ``selected_features_`` set.
        """
        cm_nw = nw.from_native(corr_mat).lazy()

        if transform:
            cm_nw = cm_nw.unpivot(index=index).rename(
                {index: "f1", "variable": "f2", "value": "correlation"}
            )

        # Every feature named on either side of a pair, collected before any
        # filtering so uncorrelated features are kept as well.
        features = (
            nw.concat(
                [
                    cm_nw.select("f1").rename({"f1": "x"}),
                    cm_nw.select("f2").rename({"f2": "x"}),
                ],
                how="vertical",
            )
            .unique()
            .collect()["x"]
            .to_list()
        )

        # Keep only distinct-feature pairs whose absolute correlation is a real
        # number at or above the threshold.
        cm_nw = (
            cm_nw.with_columns(nw.col("correlation").abs())
            .filter(
                nw.col("f1") != nw.col("f2"),
                ~nw.col("correlation").is_null(),
                ~nw.col("correlation").is_nan(),
                nw.col("correlation") >= self.threshold,
            )
            .collect()
        )

        drop_list = []
        while cm_nw.shape[0] > 0:
            to_drop, count = self._find_drop(cm_nw, self.seed)

            # len(drop_list) is the zero-based iteration number.
            logger.info(
                "Iteration %s: Dropping %s, correlated with %s other features",
                len(drop_list),
                to_drop,
                count,
            )

            # Remove every pair that involves the dropped feature.
            involves_dropped = (nw.col("f1") == to_drop) | (nw.col("f2") == to_drop)
            cm_nw = cm_nw.filter(~involves_dropped)

            drop_list.append(to_drop)

        self.selected_features_ = sorted(set(features) - set(drop_list))

        return self

    def fit(self, X: nwt.IntoFrame):
        """Compute the correlation matrix of ``X`` and select features from it.

        Returns ``self`` with ``selected_features_`` set.
        """
        return self.fit_from_correlation_matrix(correlation_matrix(X))

    def transform(self, X: nwt.IntoFrameT) -> nwt.IntoFrameT:
        """Return ``X`` restricted to the columns in ``selected_features_``."""
        return nw.from_native(X).select(self.selected_features_).to_native()

    def fit_transform(self, X: nwt.IntoFrameT) -> nwt.IntoFrameT:
        """Fit on ``X`` and return ``X`` reduced to the selected features."""
        return self.fit(X).transform(X)

tests/test_selection.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,30 @@ def test_rfe(estimator):
9494
)
9595

9696
assert rfe.selected_features_ == ["f0.99"]
97+
assert rfe.transform(X).columns == ["f0.99"]
98+
99+
100+
@pytest.mark.parametrize("estimator", ESTIMATORS)
def test_rfe_early_stopping(estimator):
    """Smoke test: RFE fits without error when an EarlyStopping callback is set."""
    fit_kwargs = {}
    if "eval_set" in inspect.signature(estimator.fit).parameters:
        fit_kwargs["eval_set"] = [(X, y)]

    def _roc_auc(est, X, y) -> float:
        return rs.metrics.roc_auc(y, est.predict(X))

    early_stopping_kwargs = {}
    # Only fall back to the predict-based metric when the estimator has no
    # predict_proba. inspect.getmembers returns (name, value) pairs, so testing
    # a bare name for membership in its result is always False; hasattr is the
    # direct check.
    if not hasattr(estimator, "predict_proba"):
        early_stopping_kwargs["metric"] = _roc_auc

    rs.selection.RFE(
        estimator=estimator,
        step=3,
        quiet=True,
        callbacks=[rs.selection.EarlyStopping(**early_stopping_kwargs)],
    ).fit(X, y, **fit_kwargs)
97121

98122

99123
@pytest.mark.parametrize("estimator", ESTIMATORS)
@@ -105,3 +129,77 @@ def test_nfe(estimator):
105129
nfe = rs.selection.NFE(estimator=estimator, seed=SEED).fit(X, y, **fit_kwargs)
106130

107131
assert "f0.99" in nfe.selected_features_
132+
assert "f0.99" in nfe.transform(X).columns
133+
134+
135+
def test_cfe():
    """CFE gives the same selection for the wide and the long input layouts."""
    wide = pl.DataFrame(
        {
            "": ["a", "b", "c"],
            "a": [1.0, 0.5, 0.7],
            "b": [-0.99, 1, 0.98],
            "c": [float("nan"), None, 1],
        }
    )
    kept = ["a", "c"]
    selector = rs.selection.CFE(threshold=0.95)

    # Wide layout: the selector unpivots the matrix itself.
    from_wide = selector.fit_from_correlation_matrix(wide)
    assert from_wide.selected_features_ == kept

    # Long layout: unpivot up front and tell the selector not to do it again.
    long = wide.unpivot(index="").rename(
        {"": "f1", "variable": "f2", "value": "correlation"}
    )
    from_long = selector.fit_from_correlation_matrix(long, transform=False)
    assert from_long.selected_features_ == kept
160+
161+
162+
def test_cfe_identity_no_drop():
    """A feature's correlation with itself must never cause it to be dropped.

    The only values at or above the threshold are on the diagonal, e.g.
    corr(a, a) = 1, so every feature should be kept.
    """
    wide = pl.DataFrame(
        {
            "": ["a", "b", "c"],
            "a": [1.0, 0.5, 0.7],
            "b": [0.5, 1.0, 0.98],
            "c": [float("nan"), None, 1],
        }
    )

    selector = rs.selection.CFE(threshold=0.99)
    fitted = selector.fit_from_correlation_matrix(wide)

    assert fitted.selected_features_ == ["a", "b", "c"]
177+
178+
179+
def test_cfe_corr_1_is_removed():
    """An off-diagonal correlation of exactly 1 must still trigger a drop.

    Guards against identity filtering that would also discard a genuine
    correlation of 1 between two different features.
    """
    wide = pl.DataFrame(
        {
            "": ["a", "b", "c"],
            "a": [1.0, 0.5, 0.7],
            "b": [0.5, 1.0, 1.0],
            "c": [float("nan"), None, 1],
        }
    )

    selector = rs.selection.CFE(threshold=0.99)
    fitted = selector.fit_from_correlation_matrix(wide)

    assert fitted.selected_features_ == ["a", "c"]
194+
195+
196+
def test_cfe_repro():
    """Two CFE fits on the same data must select exactly the same features.

    Independent random columns alone never reach the default threshold, so
    nothing would be dropped and the comparison would pass trivially. A few
    columns are duplicated so that the drop loop and its seeded tie-break
    actually run.
    """
    n_cols = 50
    n_dups = 5

    base = np.random.rand(1_000, n_cols)
    # Each appended column is an exact copy of one of the first n_dups columns,
    # so each copy forms an above-threshold pair with its source.
    data = np.concatenate([base, base[:, :n_dups]], axis=1)
    df = pl.DataFrame(data, schema=[f"col_{i}" for i in range(n_cols + n_dups)])

    first = rs.selection.CFE().fit(df).selected_features_
    second = rs.selection.CFE().fit(df).selected_features_

    # At least one of each duplicated pair must have been eliminated.
    assert len(first) < n_cols + n_dups
    assert first == second

0 commit comments

Comments
 (0)