Skip to content

Commit db93024

Browse files
authored
Merge pull request #68 from JuliaAI/dev
For a 0.6.3 release
2 parents 387c7c2 + 8d776fb commit db93024

File tree

5 files changed

+22
-3
lines changed

5 files changed

+22
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJIteration"
22
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.6.2"
4+
version = "0.6.3"
55

66
[deps]
77
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"

src/core.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,11 @@ MLJBase.transform(::EitherIteratedModel, fitresult, Xnew) =
155155
# here `fitresult` is a trained atomic machine:
156156
MLJBase.save(::EitherIteratedModel, fitresult) = MLJBase.serializable(fitresult)
157157
MLJBase.restore(::EitherIteratedModel, fitresult) = MLJBase.restore!(fitresult)
158+
159+
# Feature importances
160+
function MLJBase.feature_importances(::EitherIteratedModel, fitresult, report)
161+
# fitresult here is the curent state of the iterated machine
162+
# The line below will return `nothing` when the iteration model doesn't
163+
# support feature_importances.
164+
return MLJBase.feature_importances(fitresult)
165+
end

src/traits.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ for trait in [:supports_weights,
1515
:is_pure_julia,
1616
:input_scitype,
1717
:output_scitype,
18+
:reports_feature_importances,
1819
:target_scitype]
1920
quote
2021
# needed because traits are not always deducable from

test/core.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,10 @@ function MLJBase.fit(::EphemeralRegressor, verbosity, X, y)
272272
# if I serialize/deserialized `thing` then `id` below changes:
273273
id = objectid(thing)
274274
fitresult = (thing, id, mean(y))
275-
return fitresult, nothing, NamedTuple()
275+
report = (importances = [ftr => 1.0 for ftr in MLJBase.schema(X).names], )
276+
return fitresult, nothing, report
276277
end
278+
277279
function MLJBase.predict(::EphemeralRegressor, fitresult, X)
278280
thing, id, μ = fitresult
279281
return id == objectid(thing) ? fill(μ, nrows(X)) :
@@ -290,7 +292,12 @@ function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
290292
return (thing, id, μ)
291293
end
292294

293-
@testset "save and restore" begin
295+
MLJBase.reports_feature_importances(::Type{<:EphemeralRegressor}) = true
296+
function MLJBase.feature_importances(::EphemeralRegressor, fitresult, report)
297+
return report.importances
298+
end
299+
300+
@testset "feature importances, save and restore" begin
294301
#https://github.com/JuliaAI/MLJ.jl/issues/1099
295302
X, y = (; x = rand(10)), fill(42.0, 3)
296303
controls = [Step(1), NumberLimit(2)]
@@ -302,12 +309,14 @@ end
302309
)
303310
mach = machine(imodel, X, y)
304311
fit!(mach, verbosity=0)
312+
@test MLJBase.feature_importances(mach) == [:x => 1.0];
305313
io = IOBuffer()
306314
MLJBase.save(io, mach)
307315
seekstart(io)
308316
mach2 = machine(io)
309317
close(io)
310318
@test MLJBase.predict(mach2, (; x = rand(2))) fill(42.0, 2)
319+
311320
end
312321

313322
end

test/traits.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ imodel = IteratedModel(model=model, measure=mae)
2323
@test output_scitype(imodel) == output_scitype(model)
2424
@test target_scitype(imodel) == target_scitype(model)
2525
@test constructor(imodel) == IteratedModel
26+
@test reports_feature_importances(imodel) == reports_feature_importances(model)
2627

2728
end
2829

0 commit comments

Comments
 (0)