Commit 9c53b85
add Aqua.jl tests and refactor code
1 parent: 3e45aba

File tree: 10 files changed (+223 -161 lines)


Project.toml
Lines changed: 9 additions & 2 deletions

@@ -9,17 +9,24 @@ ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
-julia = "1"
+Aqua = "0.8"
+Distributions = "0.25"
+julia = "1.6"
+MLJBase = "1.1"
+MLJDecisionTreeInterface = "0.4"
 MLJModelInterface = "1.4"
 ScientificTypesBase = "3"
+StableRNGs = "1"
 Tables = "1.2"
+Test = "1.6"
 
 [extras]
+Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Distributions", "MLJBase", "MLJDecisionTreeInterface", "StableRNGs", "Test"]
+test = ["Aqua", "Distributions", "MLJBase", "MLJDecisionTreeInterface", "StableRNGs", "Test"]

README.md
Lines changed: 18 additions & 12 deletions

@@ -24,7 +24,7 @@ known in the R community as the friedman dataset#1. Notice how the target vector
 dataset depends on only the first five columns of feature table. So we expect that our
 recursive feature elimination should return the first columns as important features.
 ```julia
-using MLJ # or, minimally, `using FeatureSelection, MLJModels, MLJBase`
+using MLJ, FeatureSelection
 using StableRNGs
 rng = StableRNG(123)
 A = rand(rng, 50, 10)

@@ -41,20 +41,26 @@ train it on our dataset
 RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree
 forest = RandomForestRegressor()
 rfe = RecursiveFeatureElimination(
-    model = forest, n_features_to_select=5, step=1
+    model = forest, n_features=5, step=1
 ) # see doctring for description of defaults
 mach = machine(rfe, X, y)
 fit!(mach)
 ```
-We can view the important features by inspecting the `fitted_params` object.
+If we wish, we can get the feature importance scores, either by inspecting `report(mach)`
+or calling the `feature_importances` function on the fitted machine as shown below
+```julia
+report(mach).ranking # returns [1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
+feature_importances(mach) # returns dict of feature => rank pairs
+```
+We can view the important features used by our model by inspecting the `fitted_params` object.
 ```julia
 p = fitted_params(mach)
 p.features_left == [:x1, :x2, :x3, :x4, :x5]
 ```
-We can also call `predict` on the fitted machine, to predict using a
-random forest regressor trained just on those features, or call `transform`, to
-select just those features some new table including all the original features.
-in `?RecursiveFeatureElimination`.
+We can also call the `predict` method on the fitted machine, to predict using a
+random forest regressor trained using only the important features, or call the `transform`
+method, to select just those features from some new table including all the original features.
+For more info, type `?RecursiveFeatureElimination` on a Julia REPL.
 
 Okay, let's say that we didn't know that our synthetic dataset depends on only five
 columns from our feature table. We could apply cross fold validation `CV(nfolds=5)` with

@@ -66,18 +72,18 @@ rfe = RecursiveFeatureElimination(model = forest)
 tuning_rfe_model = TunedModel(
     model = rfe,
     measure = rms,
-    tuning = Grid(rng=rng, resolution=10),
-    resampling = CV(nfolds = 5),
+    tuning = Grid(rng=rng),
+    resampling = StratifiedCV(nfolds = 5),
     range = range(
-        rfe, :n_features_to_select, values = collect(2:8)
+        rfe, :n_features, lower = 1, upper=10, unit=1
     )
 )
 self_tuning_rfe_mach = machine(tuning_rfe_model, X, y)
 fit!(self_tuning_rfe_mach)
 ```
-As before we can inspect the important features by inspesting the `fitted_params` object.
+As before we can inspect the important features by inspecting the `fitted_params` object.
 ```julia
-fitted_parms(self_tuning_rfe_mach).best_model.features_left == [:x1, :x2, :x3, :x4, :x5]
+fitted_params(self_tuning_rfe_mach).best_fitted_params.features_left == [:x1, :x2, :x3, :x4, :x5]
 ```
 and call `predict` on the tuned model machine as shown below
 ```julia
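Assembled end to end, the revised README snippets amount to roughly the following workflow. This is a sketch stitched together from the hunks above, not a verbatim copy of the new README:

```julia
using MLJ, FeatureSelection
using StableRNGs

rng = StableRNG(123)
A = rand(rng, 50, 10)
y = 10 .* sin.(pi .* A[:, 1] .* A[:, 2]) .+ 20 .* (A[:, 3] .- 0.5).^2 .+
    10 .* A[:, 4] .+ 5 .* A[:, 5]
X = MLJ.table(A)

RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree
forest = RandomForestRegressor()

# eliminate features recursively until 5 remain
rfe = RecursiveFeatureElimination(model = forest, n_features = 5, step = 1)
mach = machine(rfe, X, y)
fit!(mach)

report(mach).ranking               # per-feature ranks
feature_importances(mach)          # feature => rank pairs
fitted_params(mach).features_left  # [:x1, :x2, :x3, :x4, :x5]

Xselected = transform(mach, X)     # drop the eliminated columns
yhat = predict(mach, X)            # predictions from the forest refit on the kept features
```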

src/FeatureSelection.jl
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 module FeatureSelection
 
-using MLJModelInterface, Tables, ScientificTypesBase, MLJTuning
+using MLJModelInterface, Tables, ScientificTypesBase
 
 export FeatureSelector, RecursiveFeatureElimination
 

src/models/featureselector.jl
Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ MMI.metadata_model(
     FeatureSelector,
     input_scitype = Table,
     output_scitype = Table,
-    load_path = "MLJModels.FeatureSelector"
+    load_path = "FeatureSelction.FeatureSelector"
 )
 
 """

src/models/rfe.jl
Lines changed: 34 additions & 13 deletions

@@ -11,11 +11,17 @@ const ERR_MODEL_TYPE = ArgumentError(
 )
 
 const ERR_FEATURE_IMPORTANCE_SUPPORT = ArgumentError(
-    "Model does not report feature importance, hence recursive feature algorithm "*
-    "can't be applied."
+    "Model does not report feature importance, hence recursive feature algorithm "*
+    "can't be applied."
 )
 
-const MODEL_TYPES = [:ProbabilisticRecursiveFeatureElimination, :DeterministicRecursiveFeatureElimination]
+const ERR_FEATURES_SEEN = ArgumentError(
+    "Features of new table must be same as those seen during fit process."
+)
+
+const MODEL_TYPES = [
+    :ProbabilisticRecursiveFeatureElimination, :DeterministicRecursiveFeatureElimination
+]
 const SUPER_TYPES = [:Deterministic, :Probabilistic]
 const MODELTYPE_GIVEN_SUPERTYPES = zip(MODEL_TYPES, SUPER_TYPES)
 

@@ -114,7 +120,9 @@ RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree
 
 # Creates a dataset where the target only depends on the first 5 columns of the input table.
 A = rand(rng, 50, 10);
-y = 10 .* sin.(pi .* A[:, 1] .* A[:, 2]) + 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5]);
+y = 10 .* sin.(
+    pi .* A[:, 1] .* A[:, 2]
+) + 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5]);
 X = MLJ.table(A);
 
 # fit a rfe model

@@ -189,7 +197,9 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
     Xcols = Tables.Columns(X)
     features = collect(Tables.columnnames(Xcols))
     nfeatures = length(features)
-    nfeatures < 2 && throw(ArgumentError("The number of features in the feature matrix must be at least 2."))
+    nfeatures < 2 && throw(
+        ArgumentError("The number of features in the feature matrix must be at least 2.")
+    )
 
     # Compute required number of features to select
     n_features = selector.n_features # Remember to modify this estimate later

@@ -256,12 +266,12 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
     fitresult = (
         support = support,
         model_fitresult = model_fitresult,
-        features_left = copy(features_left)
+        features_left = copy(features_left),
+        features = features
     )
     report = (
         ranking = ranking,
-        model_report = model_report,
-        features = features
+        model_report = model_report
     )
 
     return fitresult, nothing, report

@@ -282,20 +292,27 @@ function MMI.predict(model::RFE, fitresult, X)
 end
 
 function MMI.transform(::RFE, fitresult, X)
+    sch = Tables.schema(Tables.columns(X))
+    if (length(fitresult.features) == length(sch.names) &&
+        !all(e -> e in sch.names, fitresult.features))
+        throw(
+            ERR_FEATURES_SEEN
+        )
+    end
     return MMI.selectcols(X, fitresult.features_left)
 end
 
 function MMI.feature_importances(::RFE, fitresult, report)
-    return Pair.(report.features, report.ranking)
+    return Pair.(fitresult.features, report.ranking)
 end
 
 ## Traits definitions
 function MMI.load_path(::Type{<:DeterministicRecursiveFeatureElimination})
-    return "FeatureEngineering.DeterministicRecursiveFeatureElimination"
+    return "FeatureSelection.DeterministicRecursiveFeatureElimination"
 end
 
 function MMI.load_path(::Type{<:ProbabilisticRecursiveFeatureElimination})
-    return "FeatureEngineering.ProbabilisticRecursiveFeatureElimination"
+    return "FeatureSelection.ProbabilisticRecursiveFeatureElimination"
 end
 
 for trait in [

@@ -323,13 +340,17 @@ end
 
 # ## Iteration parameter
 # at level of types:
+prepend(s::Symbol, ::Nothing) = nothing
+prepend(s::Symbol, t::Symbol) = Expr(:(.), s, QuoteNode(t))
+prepend(s::Symbol, ex::Expr) = Expr(:(.), prepend(s, ex.args[1]), ex.args[2])
+
 function MMI.iteration_parameter(::Type{<:RFE{M}}) where {M}
-    return MLJModels.prepend(:model, MMI.iteration_parameter(M))
+    return prepend(:model, MMI.iteration_parameter(M))
 end
 
 # at level of instances:
 function MMI.iteration_parameter(model::RFE)
-    return MLJModels.prepend(:model, MMI.iteration_parameter(model.model))
+    return prepend(:model, MMI.iteration_parameter(model.model))
 end
 
 ## TRAINING LOSSES SUPPORT
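The `prepend` helper added here (replacing the previous call into `MLJModels.prepend`) builds a nested property path, so the wrapper reports the atomic model's iteration parameter relative to its own `model` field. A quick sketch of its behaviour; the symbols used below are illustrative only, not taken from the package:

```julia
# Definitions copied from the hunk above:
prepend(s::Symbol, ::Nothing) = nothing
prepend(s::Symbol, t::Symbol) = Expr(:(.), s, QuoteNode(t))
prepend(s::Symbol, ex::Expr) = Expr(:(.), prepend(s, ex.args[1]), ex.args[2])

prepend(:model, :epochs)          # :(model.epochs)
prepend(:model, :(inner.epochs))  # :(model.inner.epochs)
prepend(:model, nothing)          # nothing (wrapped model has no iteration parameter)
```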

test/Aqua.jl
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+using Aqua
+
+@testset "Aqua.jl" begin
+    Aqua.test_all(FeatureSelection)
+end
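`Aqua.test_all` bundles Aqua's quality checks (method ambiguities, unbound type parameters, undefined exports, stale dependencies, compat bounds, type piracy). If a particular check ever needs relaxing, it can be switched off by keyword; a hypothetical variant, not part of this commit:

```julia
using Aqua

@testset "Aqua.jl" begin
    # run everything except the method-ambiguity check
    Aqua.test_all(FeatureSelection; ambiguities = false)
end
```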

test/models/dummy_test_models.jl
Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
+module DummyTestModels
+
+using MLJBase
+
+## THE CONSTANT DETERMINISTIC REGRESSOR (FOR TESTING)
+##
+
+struct DeterministicConstantRegressor <: MLJBase.Deterministic end
+
+function MLJBase.fit(::DeterministicConstantRegressor, verbosity::Int, X, y)
+    fitresult = mean(y)
+    cache = nothing
+    report = nothing
+    return fitresult, cache, report
+end
+
+MLJBase.reformat(::DeterministicConstantRegressor, X) = (MLJBase.matrix(X),)
+MLJBase.reformat(::DeterministicConstantRegressor, X, y) = (MLJBase.matrix(X), y)
+MLJBase.selectrows(::DeterministicConstantRegressor, I, A) = (view(A, I, :),)
+MLJBase.selectrows(::DeterministicConstantRegressor, I, A, y) =
+    (view(A, I, :), y[I])
+
+MLJBase.predict(::DeterministicConstantRegressor, fitresult, Xnew) =
+    fill(fitresult, nrows(Xnew))
+end
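The dummy regressor gives the test suite a deterministic model that reports no feature importances, so the `ERR_FEATURE_IMPORTANCE_SUPPORT` path can be triggered. On its own it simply predicts the training-target mean; a sketch, assuming the module above has already been evaluated:

```julia
using MLJBase
using .DummyTestModels

X = (x1 = rand(10), x2 = rand(10))   # any Tables.jl-compatible table
y = rand(10)

reg = DummyTestModels.DeterministicConstantRegressor()

# mirror the front-end conversions a machine would apply
data = MLJBase.reformat(reg, X, y)                 # (matrix, target)
fitresult, _, _ = MLJBase.fit(reg, 0, data...)     # fitresult == mean(y)
yhat = MLJBase.predict(reg, fitresult, MLJBase.reformat(reg, X)...)
all(yhat .== fitresult)                            # constant predictions
```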

test/models/featureselector.jl
Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
+#### FEATURE SELECTOR ####
+
+@testset "Feat Selector" begin
+    N = 100
+    X = (
+        Zn = rand(N),
+        Crim = rand(N),
+        x3 = categorical(rand("YN", N)),
+        x4 = categorical(rand("YN", N))
+    )
+
+    # Test feature selection with `features=Symbol[]`
+    namesX = MLJBase.schema(X).names |> collect
+    selector = FeatureSelector()
+    f, = MLJBase.fit(selector, 1, X)
+    @test f == namesX
+    Xt = MLJBase.transform(selector, f, MLJBase.selectrows(X, 1:2))
+    @test Set(MLJBase.schema(Xt).names) == Set(namesX)
+    @test length(Xt.Zn) == 2
+
+    # Test on selecting features if `features` keyword is defined
+    selector = FeatureSelector(features=[:Zn, :Crim])
+    f, = MLJBase.fit(selector, 1, X)
+    @test MLJBase.transform(selector, f, MLJBase.selectrows(X, 1:2)) ==
+        MLJBase.select(X, 1:2, [:Zn, :Crim])
+
+    # test on ignoring a feature, even if it's listed in the `features`
+    selector.ignore = true
+    f, = MLJBase.fit(selector, 1, X)
+    Xnew = MLJBase.transform(selector, f, X)
+    @test MLJBase.transform(selector, f, MLJBase.selectrows(X, 1:2)) ==
+        MLJBase.select(X, 1:2, [:x3, :x4])
+
+    # test error about features selected or excluded in fit.
+    selector = FeatureSelector(features=[:x1, :mickey_mouse])
+    @test_throws(
+        ArgumentError,
+        MLJBase.fit(selector, 1, X)
+    )
+    selector.ignore = true
+    @test_logs(
+        (:warn, r"Excluding non-existent"),
+        MLJBase.fit(selector, 1, X)
+    )
+
+    # features must be specified if ignore=true
+    @test_throws ArgumentError FeatureSelector(ignore=true)
+
+    # test logs for no features selected when using Bool-Callable function interface:
+    selector = FeatureSelector(features= x-> x == (:x1))
+    @test_throws(
+        ArgumentError,
+        MLJBase.fit(selector, 1, X)
+    )
+    selector.ignore = true
+    selector.features = x-> x in [:Zn, :Crim, :x3, :x4]
+    @test_throws(
+        ArgumentError,
+        MLJBase.fit(selector, 1, X)
+    )
+
+    # Test model Metadata
+    @test MLJBase.input_scitype(selector) == MLJBase.Table
+    @test MLJBase.output_scitype(selector) == MLJBase.Table
+end
+
+# To be added with FeatureSelectorRule X = (n1=["a", "b", "a"], n2=["g", "g", "g"], n3=[7, 8, 9],
+# n4 =UInt8[3,5,10], o1=[4.5, 3.6, 4.0], )
+# MLJBase.schema(X)
+# Xc = coerce(X, :n1=>Multiclass, :n2=>Multiclass)
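These tests drive `FeatureSelector` through the low-level `MLJBase.fit`/`transform` API. For comparison, the same transformer through the user-facing machine interface looks roughly like this (a sketch, not part of the test suite):

```julia
using MLJBase, FeatureSelection

X = (Zn = rand(5), Crim = rand(5), x3 = rand(5))
selector = FeatureSelector(features = [:Zn, :Crim])
mach = machine(selector, X)
fit!(mach)
transform(mach, X)   # table containing only the :Zn and :Crim columns
```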

test/models/rfe.jl
Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+using .DummyTestModels
+const DTM = DummyTestModels
+
+@testset "RecursiveFeatureElimination" begin
+    #@test_throws ArgumentError RecursiveFeatureElimination(model = rf)
+    # Data For use in testset
+    X = rand(rng, 50, 10)
+    y = @views(
+        10 .* sin.(
+            pi .* X[:, 1] .* X[:, 2]
+        ) + 20 .* (X[:, 3] .- 0.5).^ 2 .+ 10 .* X[:, 4] .+ 5 * X[:, 5]
+    )
+    Xt = MLJBase.table(X)
+    Xnew = MLJBase.table(rand(rng, 50, 10))
+    Xnew2 = MLJBase.table(rand(rng, 50, 10), names = [Symbol("y$i") for i in 1:10])
+
+    # Constructor
+    @test_throws FeatureSelection.ERR_SPECIFY_MODEL RecursiveFeatureElimination()
+    reg = DTM.DeterministicConstantRegressor()
+    @test_throws(
+        FeatureSelection.ERR_FEATURE_IMPORTANCE_SUPPORT,
+        RecursiveFeatureElimination(model = DTM.DeterministicConstantRegressor())
+    )
+    rf = RandomForestRegressor()
+    selector = RecursiveFeatureElimination(model = rf)
+    @test selector isa FeatureSelection.DeterministicRecursiveFeatureElimination
+
+    # Fit
+    selector_mach = machine(selector, Xt, y)
+    fit!(selector_mach)
+    selector_fp = fitted_params(selector_mach)
+    @test propertynames(selector_fp) == (:features_left, :model_fitresult)
+    @test selector_fp.features_left == [:x1, :x2, :x3, :x4, :x5]
+    @test selector_fp.model_fitresult == MLJBase.fitted_params(
+        selector_mach.model.model, selector_mach.fitresult.model_fitresult
+    )
+    @test feature_importances(selector_mach) == [
+        :x1 => 1.0, :x2 => 1.0, :x3 => 1.0, :x4 => 1.0, :x5 => 1.0,
+        :x6 => 2.0, :x7 => 3.0, :x8 => 4.0, :x9 => 5.0, :x10 => 6.0
+    ]
+    rpt = report(selector_mach)
+    @test rpt.ranking == [
+        1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0
+    ]
+
+    # predict
+    yhat = predict(selector_mach, Xnew)
+    @test scitype(yhat) === AbstractVector{Continuous}
+
+    # transform
+    trf = transform(selector_mach, Xnew)
+    sch = MLJBase.schema(trf)
+    @test sch.names === (:x1, :x2, :x3, :x4, :x5)
+    @test sch.scitypes === (Continuous, Continuous, Continuous, Continuous, Continuous)
+    @test_throws FeatureSelection.ERR_FEATURES_SEEN transform(selector_mach, Xnew2)
+end
