Commit 5d605c6

Merge pull request #11 from JuliaAI/dev

For a 0.1.1 release

2 parents cfd3576 + ec0eb11, commit 5d605c6

File tree

5 files changed: +83 -71 lines

Project.toml

Lines changed: 3 additions & 3 deletions

@@ -1,7 +1,7 @@
 name = "FeatureSelection"
 uuid = "33837fe5-dbff-4c9e-8c2f-c5612fe2b8b6"
 authors = ["Anthony D. Blaom <[email protected]>", "Samuel Okon <[email protected]>"]
-version = "0.1.0"
+version = "0.1.1"

 [deps]
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
@@ -12,11 +12,11 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 Aqua = "0.8"
 Distributions = "0.25"
 julia = "1.6"
-MLJBase = "1.1"
+MLJBase = "1.4"
 MLJTuning = "0.8"
 MLJDecisionTreeInterface = "0.4"
 MLJScikitLearnInterface = "0.6"
-MLJModelInterface = "1.4"
+MLJModelInterface = "1.10"
 ScientificTypesBase = "3"
 StableRNGs = "1"
 StatisticalMeasures = "0.1"
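Both compat bumps track the source changes below: the new `MMI.constructor` trait used in src/models/rfe.jl appears to have landed in MLJModelInterface 1.10, with MLJBase 1.4 the matching release on the testing side. A quick sketch to check that assumption against the resolved versions:

```
# Sketch (assumption): `constructor` became available as a trait function
# in MLJModelInterface >= 1.10; this assertion fails on older versions.
import MLJModelInterface as MMI
@assert isdefined(MMI, :constructor)
```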

src/FeatureSelection.jl

Lines changed: 0 additions & 14 deletions

@@ -10,18 +10,4 @@ const MMI = MLJModelInterface
 include("models/featureselector.jl")
 include("models/rfe.jl")

-## Pkg Traits
-MMI.metadata_pkg.(
-    (
-        DeterministicRecursiveFeatureElimination,
-        ProbabilisticRecursiveFeatureElimination,
-        FeatureSelector
-    ),
-    package_name = "FeatureSelection",
-    package_uuid = "33837fe5-dbff-4c9e-8c2f-c5612fe2b8b6",
-    package_url = "https://github.com/JuliaAI/FeatureSelection.jl",
-    is_pure_julia = true,
-    package_license = "MIT"
-)
-
 end # module
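The deleted block is not lost: the `metadata_pkg` declarations move into the per-model source files changed below, so each model's package traits sit next to its implementation. Once registered, the traits can be queried on any of the models, as in this sketch (assumes the package is loaded):

```
# Sketch: querying package-level traits registered via `metadata_pkg`.
using FeatureSelection
import MLJModelInterface as MMI

MMI.package_name(FeatureSelector)     # "FeatureSelection"
MMI.package_license(FeatureSelector)  # "MIT"
MMI.is_pure_julia(FeatureSelector)    # true
```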

src/models/featureselector.jl

Lines changed: 13 additions & 2 deletions

@@ -84,9 +84,20 @@ MMI.metadata_model(
 FeatureSelector,
 input_scitype = Table,
 output_scitype = Table,
-load_path = "FeatureSelction.FeatureSelector"
+load_path = "FeatureSelection.FeatureSelector"
 )

+## Pkg Traits
+MMI.metadata_pkg(
+    FeatureSelector,
+    package_name = "FeatureSelection",
+    package_uuid = "33837fe5-dbff-4c9e-8c2f-c5612fe2b8b6",
+    package_url = "https://github.com/JuliaAI/FeatureSelection.jl",
+    is_pure_julia = true,
+    package_license = "MIT"
+)
+
+## Docstring
 """
 $(MMI.doc_header(FeatureSelector))

@@ -164,4 +175,4 @@ julia> transform(fit!(machine(selector, X)), X)

 ```
 """
-FeatureSelector
+FeatureSelector
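The one-character fix in `load_path` matters because MLJ's model registry uses that string to locate the model's code, so the misspelled `"FeatureSelction.FeatureSelector"` would not resolve. A round-trip check one might run (illustrative only, not part of the test suite):

```
# Sketch: confirm the trait string evaluates back to the type it names.
using FeatureSelection
import MLJModelInterface as MMI

path = MMI.load_path(FeatureSelector)  # "FeatureSelection.FeatureSelector"
@assert Base.eval(Main, Meta.parse(path)) === FeatureSelector
```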

src/models/rfe.jl

Lines changed: 60 additions & 48 deletions

@@ -1,7 +1,7 @@
 function warn_double_spec(arg, model)
     return "Using `model=$arg`. Ignoring keyword specification `model=$model`. "
 end
-
+
 const ERR_SPECIFY_MODEL = ArgumentError(
     "You need to specify model as positional argument or specify `model=...`."
 )
@@ -36,66 +36,67 @@ for (ModelType, ModelSuperType) in MODELTYPE_GIVEN_SUPERTYPES
     eval(ex)
 end

-eval(:(const RFE{M} = Union{$((Expr(:curly, modeltype, :M) for modeltype in MODEL_TYPES)...)}))
+eval(:(const RFE{M} =
+    Union{$((Expr(:curly, modeltype, :M) for modeltype in MODEL_TYPES)...)}))

 # Common keyword constructor for both model types
 """
     RecursiveFeatureElimination(model, n_features, step)

-This model implements a recursive feature elimination algorithm for feature selection.
-It recursively removes features, training a base model on the remaining features and
+This model implements a recursive feature elimination algorithm for feature selection.
+It recursively removes features, training a base model on the remaining features and
 evaluating their importance until the desired number of features is selected.

-Construct an instance with default hyper-parameters using the syntax
-`model = RecursiveFeatureElimination(model=...)`. Provide keyword arguments to override
-hyper-parameter defaults.
-
+Construct an instance with default hyper-parameters using the syntax
+`rfe_model = RecursiveFeatureElimination(model=...)`. Provide keyword arguments to override
+hyper-parameter defaults.
+
 # Training data
-In MLJ or MLJBase, bind an instance `model` to data with
+In MLJ or MLJBase, bind an instance `rfe_model` to data with

-    mach = machine(model, X, y)
+    mach = machine(rfe_model, X, y)

 OR, if the base model supports weights, as

-    mach = machine(model, X, y, w)
+    mach = machine(rfe_model, X, y, w)

 Here:

 - `X` is any table of input features (eg, a `DataFrame`) whose columns are of the scitype
-as that required by the base model; check column scitypes with `schema(X)` and column
+as that required by the base model; check column scitypes with `schema(X)` and column
 scitypes required by base model with `input_scitype(basemodel)`.

-- `y` is the target, which can be any table of responses whose element scitype is
-`Continuous` or `Finite` depending on the `target_scitype` required by the base model;
+- `y` is the target, which can be any table of responses whose element scitype is
+`Continuous` or `Finite` depending on the `target_scitype` required by the base model;
 check the scitype with `scitype(y)`.

-- `w` is the observation weights which can either be `nothing`(default) or an
-`AbstractVector` whoose element scitype is `Count` or `Continuous`. This is different
+- `w` is the observation weights which can either be `nothing`(default) or an
+`AbstractVector` whoose element scitype is `Count` or `Continuous`. This is different
 from `weights` kernel which is an hyperparameter to the model, see below.

 Train the machine using `fit!(mach, rows=...)`.

 # Hyper-parameters
-- model: A base model with a `fit` method that provides information on feature
+- model: A base model with a `fit` method that provides information on feature
 feature importance (i.e `reports_feature_importances(model) == true`)

-- n_features::Real = 0: The number of features to select. If `0`, half of the
-features are selected. If a positive integer, the parameter is the absolute number
-of features to select. If a real number between 0 and 1, it is the fraction of features
+- n_features::Real = 0: The number of features to select. If `0`, half of the
+features are selected. If a positive integer, the parameter is the absolute number
+of features to select. If a real number between 0 and 1, it is the fraction of features
 to select.

-- step::Real=1: If the value of step is at least 1, it signifies the quantity of features to
-eliminate in each iteration. Conversely, if step falls strictly within the range of
+- step::Real=1: If the value of step is at least 1, it signifies the quantity of features to
+eliminate in each iteration. Conversely, if step falls strictly within the range of
 0.0 to 1.0, it denotes the proportion (rounded down) of features to remove during each iteration.

 # Operations

-- `transform(mach, X)`: transform the input table `X` into a new table containing only
+- `transform(mach, X)`: transform the input table `X` into a new table containing only
 columns corresponding to features gotten from the RFE algorithm.

-- `predict(mach, X)`: transform the input table `X` into a new table same as in
+- `predict(mach, X)`: transform the input table `X` into a new table same as in

-- `transform(mach, X)` above and predict using the fitted base model on the
+- `transform(mach, X)` above and predict using the fitted base model on the
 transformed table.

 # Fitted parameters
@@ -106,11 +107,11 @@ The fields of `fitted_params(mach)` are:

 # Report
 The fields of `report(mach)` are:
-- `ranking`: The feature ranking of each features in the training dataset.
+- `ranking`: The feature ranking of each features in the training dataset.

 - `model_report`: report for the fitted base model.

-- `features`: names of features seen during the training process.
+- `features`: names of features seen during the training process.

 # Examples
 ```
@@ -131,10 +132,10 @@ selector = RecursiveFeatureElimination(model = rf)
 mach = machine(selector, X, y)
 fit!(mach)

-# view the feature importances
+# view the feature importances
 feature_importances(mach)

-# predict using the base model
+# predict using the base model
 Xnew = MLJ.table(rand(rng, 50, 10));
 predict(mach, Xnew)
@@ -160,7 +161,7 @@ function RecursiveFeatureElimination(
     #TODO: Check that the specifed model implements the predict method.
     # probably add a trait to check this
     MMI.reports_feature_importances(model) || throw(ERR_FEATURE_IMPORTANCE_SUPPORT)
-    if model isa Deterministic
+    if model isa Deterministic
         selector = DeterministicRecursiveFeatureElimination{typeof(model)}(
             model, Float64(n_features), Float64(step)
         )
@@ -170,7 +171,7 @@
         )
     else
         throw(ERR_MODEL_TYPE)
-    end
+    end
     message = MMI.clean!(selector)
     isempty(message) || @warn(message)
     return selector
@@ -204,21 +205,21 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
     n_features_select = selector.n_features
     ## zero indicates that half of the features be selected.
     if n_features_select == 0
-        n_features_select = div(nfeatures, 2)
+        n_features_select = div(nfeatures, 2)
     elseif 0 < n_features_select < 1
         n_features_select = round(Int, n_features_select * nfeatures)
     else
         n_features_select = round(Int, n_features_select)
     end

     step = selector.step
-
+
     if 0 < step < 1
         step = round(Int, max(1, step * n_features_select))
     else
-        step = round(Int, step)
+        step = round(Int, step)
     end
-
+
     support = trues(nfeatures)
     ranking = ones(Int, nfeatures) # every feature has equal rank initially
     mask = trues(nfeatures) # for boolean indexing of ranking vector in while loop below
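As an aside, the selection-size arithmetic in the hunk above reads more easily as a standalone function. A sketch, not part of the package's API:

```
# Sketch of the logic above: `n_features == 0` keeps half the features,
# a fraction in (0, 1) keeps that proportion, anything else is an absolute
# count; a fractional `step` is a proportion of the target feature count.
function resolve_rfe_sizes(n_features::Real, step::Real, nfeatures::Int)
    n_select = n_features == 0 ? div(nfeatures, 2) :
        0 < n_features < 1 ? round(Int, n_features * nfeatures) :
        round(Int, n_features)
    step_int = 0 < step < 1 ? round(Int, max(1, step * n_select)) : round(Int, step)
    return n_select, step_int
end

resolve_rfe_sizes(0, 1, 10)     # (5, 1): zero means "keep half"
resolve_rfe_sizes(0.3, 0.5, 10) # (3, 2): fractions are proportions
```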
@@ -230,7 +231,7 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
         # Rank the remaining features
         model = selector.model
         verbosity > 0 && @info("Fitting estimator with $(n_features_left) features.")
-
+
         data = MMI.reformat(model, MMI.selectcols(X, features_left), args...)

         fitresult, _, report = MMI.fit(model, verbosity - 1, data...)
@@ -263,14 +264,14 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
     data = MMI.reformat(selector.model, MMI.selectcols(X, features_left), args...)
     verbosity > 0 && @info ("Fitting estimator with $(n_features_left) features.")
     model_fitresult, _, model_report = MMI.fit(selector.model, verbosity - 1, data...)
-
+
     fitresult = (
         support = support,
         model_fitresult = model_fitresult,
         features_left = features_left,
         features = features
     )
-    report = (
+    report = (
         ranking = ranking,
         model_report = model_report
     )
@@ -294,7 +295,7 @@ end

 function MMI.transform(::RFE, fitresult, X)
     sch = Tables.schema(Tables.columns(X))
-    if (length(fitresult.features) == length(sch.names) &&
+    if (length(fitresult.features) == length(sch.names) &&
         !all(e -> e in sch.names, fitresult.features))
         throw(
             ERR_FEATURES_SEEN
@@ -312,7 +313,7 @@ function MMI.save(model::RFE, fitresult)
     atomic_fitresult = fitresult.model_fitresult
     features_left = fitresult.features_left
     features = fitresult.features
-
+
     atom = model.model
     return (
         support = copy(support),
@@ -337,14 +338,12 @@ function MMI.restore(model::RFE, serializable_fitresult)
     )
 end

-## Traits definitions
-function MMI.load_path(::Type{<:DeterministicRecursiveFeatureElimination})
-    return "FeatureSelection.DeterministicRecursiveFeatureElimination"
-end
+## Trait definitions

-function MMI.load_path(::Type{<:ProbabilisticRecursiveFeatureElimination})
-    return "FeatureSelection.ProbabilisticRecursiveFeatureElimination"
-end
+# load path points to constructor not type:
+MMI.load_path(::Type{<:RFE}) = "FeatureSelection.RecursiveFeatureElimination"
+MMI.constructor(::Type{<:RFE}) = RecursiveFeatureElimination
+MMI.package_name(::Type{<:RFE}) = "FeatureSelection"

 for trait in [
     :supports_weights,
@@ -387,4 +386,17 @@ end
 ## TRAINING LOSSES SUPPORT
 function MMI.training_losses(model::RFE, rfe_report)
     return MMI.training_losses(model.model, rfe_report.model_report)
-end
+end
+
+## Pkg Traits
+MMI.metadata_pkg.(
+    (
+        DeterministicRecursiveFeatureElimination,
+        ProbabilisticRecursiveFeatureElimination,
+    ),
+    package_name = "FeatureSelection",
+    package_uuid = "33837fe5-dbff-4c9e-8c2f-c5612fe2b8b6",
+    package_url = "https://github.com/JuliaAI/FeatureSelection.jl",
+    is_pure_julia = true,
+    package_license = "MIT"
+)
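The trait consolidation above is the substantive change of the release: the two per-type `load_path` methods collapse into one method on the `RFE` union that points at the keyword constructor, and the new `constructor` trait records that mapping explicitly. A sketch of what this enables downstream, where `atom` stands in for any base model reporting feature importances:

```
# Sketch: generic code can rebuild either concrete RFE wrapper without
# knowing its type. `atom` is a hypothetical base model, not defined here.
using FeatureSelection
import MLJModelInterface as MMI

rfe = RecursiveFeatureElimination(model = atom)
C = MMI.constructor(typeof(rfe))  # RecursiveFeatureElimination
rfe2 = C(model = atom, n_features = rfe.n_features, step = rfe.step)
```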

test/models/rfe.jl

Lines changed: 7 additions & 4 deletions

@@ -17,12 +17,15 @@ const DTM = DummyTestModels
 @test_throws FeatureSelection.ERR_SPECIFY_MODEL RecursiveFeatureElimination()
 reg = DTM.DeterministicConstantRegressor()
 @test_throws(
-    FeatureSelection.ERR_FEATURE_IMPORTANCE_SUPPORT,
+    FeatureSelection.ERR_FEATURE_IMPORTANCE_SUPPORT,
     RecursiveFeatureElimination(model = DTM.DeterministicConstantRegressor())
 )
 rf = MLJDecisionTreeInterface.RandomForestRegressor(rng = rng)
 selector = RecursiveFeatureElimination(model = rf)
 @test selector isa FeatureSelection.DeterministicRecursiveFeatureElimination
+@test MLJBase.constructor(selector) == RecursiveFeatureElimination
+@test MLJBase.package_name(selector) == "FeatureSelection"
+@test MLJBase.load_path(selector) == "FeatureSelection.RecursiveFeatureElimination"

 # Fit
 selector_mach = machine(selector, Xt, y)
@@ -34,7 +37,7 @@ const DTM = DummyTestModels
     selector_mach.model.model, selector_mach.fitresult.model_fitresult
 )
 @test feature_importances(selector_mach) == [
-    :x1 => 6.0, :x2 => 5.0, :x3 => 4.0, :x4 => 3.0, :x5 => 2.0,
+    :x1 => 6.0, :x2 => 5.0, :x3 => 4.0, :x4 => 3.0, :x5 => 2.0,
     :x6 => 1.0, :x7 => 1.0, :x8 => 1.0, :x9 => 1.0, :x10 => 1.0
 ]
 rpt = report(selector_mach)
@@ -94,7 +97,7 @@ end
     measure = rms,
     tuning = Grid(rng=rng),
     resampling = StratifiedCV(nfolds = 5),
-    range = range(rfecv, :n_features, values = 1:10)
+    range = range(rfecv, :n_features, values = 1:10)
 )
 self_tuning_rfe_mach = machine(tuning_rfe_model, Xs, ys)
 fit!(self_tuning_rfe_mach)
@@ -127,4 +130,4 @@ end
 mach2 = MLJBase.machine(io)
 close(io)
 @test MLJBase.predict(mach2, (; x1=rand(2), x2 = rand(2))) == yhat
-end
+end
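The last hunk closes a serialization round trip; in isolation the pattern under test looks like this sketch, which assumes a fitted machine `mach`:

```
# Sketch of the save/restore round trip exercised above.
using MLJBase
io = IOBuffer()
MLJBase.save(io, mach)       # routes through MMI.save for the RFE fitresult
seekstart(io)
mach2 = MLJBase.machine(io)  # routes through MMI.restore
close(io)
# Predictions from the restored machine should match the original:
# MLJBase.predict(mach2, Xnew) == MLJBase.predict(mach, Xnew)
```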
