Commit 2c79259

fix bug, add support for serialization and add more tests
1 parent eed3af1 commit 2c79259

6 files changed: +221 additions, -50 deletions

Project.toml

Lines changed: 17 additions & 1 deletion
```diff
@@ -13,20 +13,36 @@ Aqua = "0.8"
 Distributions = "0.25"
 julia = "1.6"
 MLJBase = "1.1"
+MLJTuning = "0.8"
 MLJDecisionTreeInterface = "0.4"
+MLJScikitLearnInterface = "0.6"
 MLJModelInterface = "1.4"
 ScientificTypesBase = "3"
 StableRNGs = "1"
+StatisticalMeasures = "0.1"
 Tables = "1.2"
 Test = "1.6"

 [extras]
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
+MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"
 MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
+MLJScikitLearnInterface = "5ae90465-5518-4432-b9d2-8a1def2f0cab"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
+StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["Aqua", "Distributions", "MLJBase", "MLJDecisionTreeInterface", "StableRNGs", "Test"]
+test = [
+    "Aqua",
+    "Distributions",
+    "MLJBase",
+    "MLJTuning",
+    "MLJDecisionTreeInterface",
+    "MLJScikitLearnInterface",
+    "StableRNGs",
+    "StatisticalMeasures",
+    "Test"
+]
```

README.md

Lines changed: 22 additions & 16 deletions
````diff
@@ -1,8 +1,8 @@
 # FeatureSelection.jl

-| Linux | Coverage |
-| :------------ | :------- |
-| [![Build Status](https://github.com/JuliaAI/FeatureSelection.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/FeatureSelection.jl/actions) | [![Coverage](https://codecov.io/gh/JuliaAI/FeatureSelection.jl/branch/master/graph/badge.svg)](https://codecov.io/github/JuliaAI/FeatureSelection.jl?branch=dev) |
+| Linux | Coverage | Code Style
+| :------------ | :------- | :------------- |
+| [![Build Status](https://github.com/JuliaAI/FeatureSelection.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/FeatureSelection.jl/actions) | [![Coverage](https://codecov.io/gh/JuliaAI/FeatureSelection.jl/branch/master/graph/badge.svg)](https://codecov.io/github/JuliaAI/FeatureSelection.jl?branch=dev) | [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) |

 Repository housing feature selection algorithms for use with the machine learning toolbox
 [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/).
@@ -26,20 +26,20 @@ recursive feature elimination should return the first columns as important featu
 ```julia
 using MLJ, FeatureSelection
 using StableRNGs
-rng = StableRNG(123)
+rng = StableRNG(10)
 A = rand(rng, 50, 10)
 X = MLJ.table(A) # features
 y = @views(
     10 .* sin.(
         pi .* A[:, 1] .* A[:, 2]
-    ) + 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5]
+    ) .+ 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5]
 ) # target
 ```
 Now we that we have our data we can create our recursive feature elimination model and
 train it on our dataset
 ```julia
 RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree
-forest = RandomForestRegressor()
+forest = RandomForestRegressor(rng=rng)
 rfe = RecursiveFeatureElimination(
     model = forest, n_features=5, step=1
 ) # see doctring for description of defaults
@@ -48,24 +48,28 @@ fit!(mach)
 ```
 We can inspect the feature importances in two ways:
 ```julia
+# A variable with lower rank has more significance than a variable with higher rank.
+# A variable with Higher feature importance is better than a variable with lower
+# feature importance
 report(mach).ranking # returns [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
-feature_importances(mach) # returns dict of feature => rank pairs
+feature_importances(mach) # returns dict of feature => importance pairs
 ```
-We can view the important features used by our model by inspecting the `fitted_params` object.
+We can view the important features used by our model by inspecting the `fitted_params`
+object.
 ```julia
 p = fitted_params(mach)
 p.features_left == [:x1, :x2, :x3, :x4, :x5]
 ```
 We can also call the `predict` method on the fitted machine, to predict using a
 random forest regressor trained using only the important features, or call the `transform`
-method, to select just those features from some new table including all the original features.
-For more info, type `?RecursiveFeatureElimination` on a Julia REPL.
+method, to select just those features from some new table including all the original
+features. For more info, type `?RecursiveFeatureElimination` on a Julia REPL.

 Okay, let's say that we didn't know that our synthetic dataset depends on only five
-columns from our feature table. We could apply cross fold validation `CV(nfolds=5)` with
-our recursive feature elimination model to select the optimal value of
-`n_features` for our model. In this case we will use a simple Grid search with root mean
-square as the measure.
+columns from our feature table. We could apply cross fold validation
+`StratifiedCV(nfolds=5)` with our recursive feature elimination model to select the
+optimal value of `n_features` for our model. In this case we will use a simple Grid
+search with root mean square as the measure.
 ```julia
 rfe = RecursiveFeatureElimination(model = forest)
 tuning_rfe_model = TunedModel(
@@ -74,15 +78,17 @@ tuning_rfe_model = TunedModel(
     tuning = Grid(rng=rng),
     resampling = StratifiedCV(nfolds = 5),
     range = range(
-        rfe, :n_features, lower = 1, upper=10, unit=1
+        rfe, :n_features, values = 1:10
     )
 )
 self_tuning_rfe_mach = machine(tuning_rfe_model, X, y)
 fit!(self_tuning_rfe_mach)
 ```
-As before we can inspect the important features by inspecting the `fitted_params` object.
+As before we can inspect the important features by inspecting the object returned by
+`fitted_params` or `feature_importances` as shown below.
 ```julia
 fitted_params(self_tuning_rfe_mach).best_fitted_params.features_left == [:x1, :x2, :x3, :x4, :x5]
+feature_importances(self_tuning_rfe_mach) # returns dict of feature => importance pairs
 ```
 and call `predict` on the tuned model machine as shown below
 ```julia
````
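For illustration only (not part of this diff): assuming the usual MLJTuning accessors `best_model` and `best_history_entry` are available on the tuned machine, a reader following the README example could inspect the winning value of `n_features` and then predict as with any other machine:

```julia
# Hypothetical follow-up to the README example above (not in this commit); assumes
# the standard MLJTuning report/fitted_params fields `best_model` and
# `best_history_entry` exist for `self_tuning_rfe_mach`.
best_rfe = fitted_params(self_tuning_rfe_mach).best_model
best_rfe.n_features                      # n_features chosen by the grid search

entry = report(self_tuning_rfe_mach).best_history_entry
entry.measurement                        # measure (root mean square) for the best model

predict(self_tuning_rfe_mach, X)         # predict with the tuned machine, as usual
```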

src/models/rfe.jl

Lines changed: 56 additions & 24 deletions
```diff
@@ -202,34 +202,35 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
     )

     # Compute required number of features to select
-    n_features = selector.n_features # Remember to modify this estimate later
+    n_features_select = selector.n_features
     ## zero indicates that half of the features be selected.
-    if n_features == 0
-        n_features = div(nfeatures, 2)
-    elseif 0 < n_features < 1
-        n_features = round(Int, n_features * n_features)
+    if n_features_select == 0
+        n_features_select = div(nfeatures, 2)
+    elseif 0 < n_features_select < 1
+        n_features_select = round(Int, n_features_select * nfeatures)
     else
-        n_features = round(Int, n_features)
+        n_features_select = round(Int, n_features_select)
     end

     step = selector.step

     if 0 < step < 1
-        step = round(Int, max(1, step * n_features))
+        step = round(Int, max(1, step * n_features_select))
     else
         step = round(Int, step)
     end

     support = trues(nfeatures)
-    ranking = ones(nfeatures) # every feature has equal rank initially
-    indexes = axes(support, 1)
+    ranking = ones(Int, nfeatures) # every feature has equal rank initially
+    mask = trues(nfeatures) # for boolean indexing of ranking vector in while loop below

     # Elimination
-    features_left = copy(features)
-    while sum(support) > n_features
+    features_left = features
+    n_features_left = length(features_left)
+    while n_features_left > n_features_select
         # Rank the remaining features
         model = selector.model
-        verbosity > 0 && @info("Fitting estimator with $(sum(support)) features.")
+        verbosity > 0 && @info("Fitting estimator with $(n_features_left) features.")

         data = MMI.reformat(model, MMI.selectcols(X, features_left), args...)

@@ -249,24 +250,25 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
         ranks = sortperm(importances)

         # Eliminate the worse features
-        threshold = min(step, sum(support) - n_features)
-
-        support[indexes[ranks][1:threshold]] .= false
-        ranking[.!support] .+= 1
+        threshold = min(step, n_features_left - n_features_select)
+        @views(support[support][ranks[1:threshold]]) .= false
+        mask .= support .== false
+        @views(ranking[mask]) .+= 1

         # Remaining features
-        features_left = @view(features[support])
+        features_left = features[support]
+        n_features_left = length(features_left)
     end

     # Set final attributes
     data = MMI.reformat(selector.model, MMI.selectcols(X, features_left), args...)
-    verbosity > 0 && @info ("Fitting estimator with $(sum(support)) features.")
+    verbosity > 0 && @info ("Fitting estimator with $(n_features_left) features.")
     model_fitresult, _, model_report = MMI.fit(selector.model, verbosity - 1, data...)

     fitresult = (
         support = support,
         model_fitresult = model_fitresult,
-        features_left = copy(features_left),
+        features_left = features_left,
         features = features
     )
     report = (
@@ -280,7 +282,7 @@ end

 function MMI.fitted_params(model::RFE, fitresult)
     (
-        features_left = fitresult.features_left,
+        features_left = copy(fitresult.features_left),
         model_fitresult = MMI.fitted_params(model.model, fitresult.model_fitresult)
     )
 end
@@ -295,15 +297,45 @@ function MMI.transform(::RFE, fitresult, X)
     sch = Tables.schema(Tables.columns(X))
     if (length(fitresult.features) == length(sch.names) &&
         !all(e -> e in sch.names, fitresult.features))
-    throw(
-        ERR_FEATURES_SEEN
-    )
+        throw(
+            ERR_FEATURES_SEEN
+        )
     end
     return MMI.selectcols(X, fitresult.features_left)
 end

 function MMI.feature_importances(::RFE, fitresult, report)
-    return Pair.(fitresult.features, report.ranking)
+    return Pair.(fitresult.features, Iterators.reverse(report.ranking))
+end
+
+function MMI.save(model::RFE, fitresult)
+    support = fitresult.support
+    atomic_fitresult = fitresult.model_fitresult
+    features_left = fitresult.features_left
+    features = fitresult.features
+
+    atom = model.model
+    return (
+        support = copy(support),
+        model_fitresult = MMI.save(atom, atomic_fitresult),
+        features_left = copy(features_left),
+        features = copy(features)
+    )
+end
+
+function MMI.restore(model::RFE, serializable_fitresult)
+    support = serializable_fitresult.support
+    atomic_serializable_fitresult = serializable_fitresult.model_fitresult
+    features_left = serializable_fitresult.features_left
+    features = serializable_fitresult.features
+
+    atom = model.model
+    return (
+        support = support,
+        model_fitresult = MMI.restore(atom, atomic_serializable_fitresult),
+        features_left = features_left,
+        features = features
+    )
 end

 ## Traits definitions
```
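The new `MMI.save`/`MMI.restore` methods delegate to the wrapped (atomic) model, so a fitted recursive feature elimination result can be serialized even when the atomic model needs special handling. As a minimal sketch (not taken from this diff, and assuming `rfe`, `X` and `y` from the README example are in scope), the round trip at the MLJModelInterface level looks roughly like this:

```julia
# Minimal sketch, not from this commit. Assumes `rfe = RecursiveFeatureElimination(...)`
# and the table `X` / target `y` from the README example are already defined.
using MLJModelInterface
const MMI = MLJModelInterface

fitresult, cache, report = MMI.fit(rfe, 0, X, y)

serializable = MMI.save(rfe, fitresult)        # atomic fitresult converted via MMI.save(atom, ...)
restored     = MMI.restore(rfe, serializable)  # atomic fitresult rebuilt via MMI.restore(atom, ...)

MMI.transform(rfe, restored, X)                # selected columns, same as before saving
```

In an MLJ workflow these methods are not usually called directly; they are invoked behind the scenes when a fitted machine is saved and reloaded (for example with `MLJBase.save` and `machine`).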

test/models/dummy_test_models.jl

Lines changed: 46 additions & 4 deletions
```diff
@@ -1,6 +1,7 @@
 module DummyTestModels

 using MLJBase
+using Distributions

 ## THE CONSTANT DETERMINISTIC REGRESSOR (FOR TESTING)
 ##
@@ -17,9 +18,50 @@ end
 MLJBase.reformat(::DeterministicConstantRegressor, X) = (MLJBase.matrix(X),)
 MLJBase.reformat(::DeterministicConstantRegressor, X, y) = (MLJBase.matrix(X), y)
 MLJBase.selectrows(::DeterministicConstantRegressor, I, A) = (view(A, I, :),)
-MLJBase.selectrows(::DeterministicConstantRegressor, I, A, y) =
-    (view(A, I, :), y[I])
+function MLJBase.selectrows(::DeterministicConstantRegressor, I, A, y)
+    return (view(A, I, :), y[I])
+end
+
+function MLJBase.predict(::DeterministicConstantRegressor, fitresult, Xnew)
+    return fill(fitresult, nrows(Xnew))
+end
+
+## THE EphemeralClassifier (FOR TESTING)
+## Define a Deterministic Classifier with non-persistent `fitresult`, but which addresses
+## this by overloading `save`/`restore`:
+struct EphemeralClassifier <: MLJBase.Deterministic end
+thing = []
+
+function MLJBase.fit(::EphemeralClassifier, verbosity, X, y)
+    # if I serialize/deserialized `thing` then `id` below changes:
+    id = objectid(thing)
+    p = Distributions.fit(UnivariateFinite, y)
+    fitresult = (thing, id, p)
+    report = (features = MLJBase.schema(X).names,)
+    return fitresult, nothing, report
+end
+
+function MLJBase.predict(::EphemeralClassifier, fitresult, X)
+    thing, id, p = fitresult
+    id == objectid(thing) || throw(ErrorException("dead fitresult"))
+    return [mode(p) for _ in 1:MLJBase.nrows(X)]
+end
+
+function MLJBase.feature_importances(model::EphemeralClassifier, fitresult, report)
+    return [ftr => 1.0 for ftr in report.features]
+end
+
+MLJBase.target_scitype(::Type{<:EphemeralClassifier}) = AbstractVector{OrderedFactor{2}}
+MLJBase.reports_feature_importances(::Type{<:EphemeralClassifier}) = true
+
+function MLJBase.save(::EphemeralClassifier, fitresult)
+    thing, _, p = fitresult
+    return (thing, p)
+end
+function MLJBase.restore(::EphemeralClassifier, serialized_fitresult)
+    thing, p = serialized_fitresult
+    id = objectid(thing)
+    return (thing, id, p)
+end

-MLJBase.predict(::DeterministicConstantRegressor, fitresult, Xnew) =
-    fill(fitresult, nrows(Xnew))
 end
```
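To make clear what `EphemeralClassifier` is for: its `save` method strips the non-persistent `id` from the fitresult, and `restore` recomputes it, so a restored fitresult passes the `objectid` check in `predict`. A rough sketch of the round trip such a test model is meant to exercise (not taken from this diff, and assuming the module above has been `include`d) is:

```julia
# Rough sketch of a round-trip check, not from this commit. Assumes
# `DummyTestModels` (above) is available in scope.
using MLJBase

X = (x1 = rand(8), x2 = rand(8))
y = coerce(rand(["a", "b"], 8), OrderedFactor)   # matches the declared target_scitype

model = DummyTestModels.EphemeralClassifier()
fitresult, _, report = MLJBase.fit(model, 0, X, y)

serialized = MLJBase.save(model, fitresult)      # drops the ephemeral `id`
restored   = MLJBase.restore(model, serialized)  # recomputes `id = objectid(thing)`

MLJBase.predict(model, restored, X)              # objectid check passes, so no "dead fitresult" error
```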
