@@ -272,8 +272,10 @@ function MLJBase.fit(::EphemeralRegressor, verbosity, X, y)
272272 # if I serialize/deserialized `thing` then `id` below changes:
273273 id = objectid(thing)
274274 fitresult = (thing, id, mean(y))
275- return fitresult, nothing , NamedTuple()
275+ report = (importances = [ftr => 1.0 for ftr in MLJBase. schema(X). names], )
276+ return fitresult, nothing , report
276277end
278+
277279function MLJBase. predict(:: EphemeralRegressor , fitresult, X)
278280 thing, id, μ = fitresult
279281 return id == objectid(thing) ? fill(μ, nrows(X)) :
@@ -290,7 +292,12 @@ function MLJBase.restore(::EphemeralRegressor, serialized_fitresult)
290292 return (thing, id, μ)
291293end
292294
293- @testset " save and restore" begin
295+ MLJBase. reports_feature_importances(:: Type{<:EphemeralRegressor} ) = true
296+ function MLJBase. feature_importances(:: EphemeralRegressor , fitresult, report)
297+ return report. importances
298+ end
299+
300+ @testset " feature importances, save and restore" begin
294301 # https://github.com/JuliaAI/MLJ.jl/issues/1099
295302 X, y = (; x = rand(10 )), fill(42.0 , 3 )
296303 controls = [Step(1 ), NumberLimit(2 )]
@@ -302,12 +309,14 @@ end
302309 )
303310 mach = machine(imodel, X, y)
304311 fit!(mach, verbosity= 0 )
312+ @test MLJBase. feature_importances(mach) == [:x => 1.0 ];
305313 io = IOBuffer()
306314 MLJBase. save(io, mach)
307315 seekstart(io)
308316 mach2 = machine(io)
309317 close(io)
310318 @test MLJBase. predict(mach2, (; x = rand(2 ))) ≈ fill(42.0 , 2 )
319+
311320end
312321
313322end
0 commit comments