@@ -11,11 +11,17 @@ const ERR_MODEL_TYPE = ArgumentError(
11
11
)
12
12
13
13
const ERR_FEATURE_IMPORTANCE_SUPPORT = ArgumentError (
14
- " Model does not report feature importance, hence recursive feature algorithm " *
15
- " can't be applied."
14
+ " Model does not report feature importance, hence recursive feature algorithm " *
15
+ " can't be applied."
16
16
)
17
17
18
- const MODEL_TYPES = [:ProbabilisticRecursiveFeatureElimination , :DeterministicRecursiveFeatureElimination ]
18
+ const ERR_FEATURES_SEEN = ArgumentError (
19
+ " Features of new table must be same as those seen during fit process."
20
+ )
21
+
22
+ const MODEL_TYPES = [
23
+ :ProbabilisticRecursiveFeatureElimination , :DeterministicRecursiveFeatureElimination
24
+ ]
19
25
const SUPER_TYPES = [:Deterministic , :Probabilistic ]
20
26
const MODELTYPE_GIVEN_SUPERTYPES = zip (MODEL_TYPES, SUPER_TYPES)
21
27
@@ -114,7 +120,9 @@ RandomForestRegressor = @load RandomForestRegressor pkg=DecisionTree
114
120
115
121
# Creates a dataset where the target only depends on the first 5 columns of the input table.
116
122
A = rand(rng, 50, 10);
117
- y = 10 .* sin.(pi .* A[:, 1] .* A[:, 2]) + 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5]);
123
+ y = 10 .* sin.(
124
+ pi .* A[:, 1] .* A[:, 2]
125
+ ) + 20 .* (A[:, 3] .- 0.5).^ 2 .+ 10 .* A[:, 4] .+ 5 * A[:, 5]);
118
126
X = MLJ.table(A);
119
127
120
128
# fit a rfe model
@@ -189,7 +197,9 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
189
197
Xcols = Tables. Columns (X)
190
198
features = collect (Tables. columnnames (Xcols))
191
199
nfeatures = length (features)
192
- nfeatures < 2 && throw (ArgumentError (" The number of features in the feature matrix must be at least 2." ))
200
+ nfeatures < 2 && throw (
201
+ ArgumentError (" The number of features in the feature matrix must be at least 2." )
202
+ )
193
203
194
204
# Compute required number of features to select
195
205
n_features = selector. n_features # Remember to modify this estimate later
@@ -256,12 +266,12 @@ function MMI.fit(selector::RFE, verbosity::Int, X, y, args...)
256
266
fitresult = (
257
267
support = support,
258
268
model_fitresult = model_fitresult,
259
- features_left = copy (features_left)
269
+ features_left = copy (features_left),
270
+ features = features
260
271
)
261
272
report = (
262
273
ranking = ranking,
263
- model_report = model_report,
264
- features = features
274
+ model_report = model_report
265
275
)
266
276
267
277
return fitresult, nothing , report
@@ -282,20 +292,27 @@ function MMI.predict(model::RFE, fitresult, X)
282
292
end
283
293
284
294
function MMI. transform (:: RFE , fitresult, X)
295
+ sch = Tables. schema (Tables. columns (X))
296
+ if (length (fitresult. features) == length (sch. names) &&
297
+ ! all (e -> e in sch. names, fitresult. features))
298
+ throw (
299
+ ERR_FEATURES_SEEN
300
+ )
301
+ end
285
302
return MMI. selectcols (X, fitresult. features_left)
286
303
end
287
304
288
305
function MMI. feature_importances (:: RFE , fitresult, report)
289
- return Pair .(report . features, report. ranking)
306
+ return Pair .(fitresult . features, report. ranking)
290
307
end
291
308
292
309
# # Traits definitions
293
310
function MMI. load_path (:: Type{<:DeterministicRecursiveFeatureElimination} )
294
- return " FeatureEngineering .DeterministicRecursiveFeatureElimination"
311
+ return " FeatureSelection .DeterministicRecursiveFeatureElimination"
295
312
end
296
313
297
314
function MMI. load_path (:: Type{<:ProbabilisticRecursiveFeatureElimination} )
298
- return " FeatureEngineering .ProbabilisticRecursiveFeatureElimination"
315
+ return " FeatureSelection .ProbabilisticRecursiveFeatureElimination"
299
316
end
300
317
301
318
for trait in [
@@ -323,13 +340,17 @@ end
323
340
324
341
# ## Iteration parameter
325
342
# at level of types:
343
+ prepend (s:: Symbol , :: Nothing ) = nothing
344
+ prepend (s:: Symbol , t:: Symbol ) = Expr (:(.), s, QuoteNode (t))
345
+ prepend (s:: Symbol , ex:: Expr ) = Expr (:(.), prepend (s, ex. args[1 ]), ex. args[2 ])
346
+
326
347
function MMI. iteration_parameter (:: Type{<:RFE{M}} ) where {M}
327
- return MLJModels . prepend (:model , MMI. iteration_parameter (M))
348
+ return prepend (:model , MMI. iteration_parameter (M))
328
349
end
329
350
330
351
# at level of instances:
331
352
function MMI. iteration_parameter (model:: RFE )
332
- return MLJModels . prepend (:model , MMI. iteration_parameter (model. model))
353
+ return prepend (:model , MMI. iteration_parameter (model. model))
333
354
end
334
355
335
356
# # TRAINING LOSSES SUPPORT
0 commit comments