|
58 | 58 | @test predict(smach, X) == predict(mach, X) |
59 | 59 |
|
60 | 60 | rm(filename) |
| 61 | +end |
61 | 62 |
|
| 63 | +# define a supervised model with ephemeral `fitresult`, but which overcomes this by |
| 64 | +# overloading `save`/`restore`: |
| 65 | +thing = [] |
| 66 | +struct EphemeralRegressor <: Deterministic end |
| 67 | +function MLJBase.fit(::EphemeralRegressor, verbosity, X, y) |
| 68 | + # if I serialize/deserialized `thing` then `id` below changes: |
| 69 | + id = objectid(thing) |
| 70 | + fitresult = (thing, id, mean(y)) |
| 71 | + return fitresult, nothing, NamedTuple() |
| 72 | +end |
| 73 | +function MLJBase.predict(::EphemeralRegressor, fitresult, X) |
| 74 | + thing, id, μ = fitresult |
| 75 | + return id == objectid(thing) ? fill(μ, nrows(X)) : |
| 76 | + throw(ErrorException("dead fitresult")) |
| 77 | +end |
| 78 | +MLJBase.target_scitype(::Type{<:EphemeralRegressor}) = AbstractVector{Continuous} |
| 79 | +function MLJBase.save(::EphemeralRegressor, fitresult) |
| 80 | + thing, _, μ = fitresult |
| 81 | + return (thing, μ) |
| 82 | +end |
| 83 | +function MLJBase.restore(::EphemeralRegressor, serialized_fitresult) |
| 84 | + thing, μ = serialized_fitresult |
| 85 | + id = objectid(thing) |
| 86 | + return (thing, id, μ) |
| 87 | +end |
| 88 | + |
| 89 | +@testset "serialization for atomic models with non-persistent fitresults" begin |
| 90 | + # https://github.com/alan-turing-institute/MLJ.jl/issues/1099 |
| 91 | + X, y = (; x = rand(10)), fill(42.0, 3) |
| 92 | + ensemble = EnsembleModel( |
| 93 | + EphemeralRegressor(), |
| 94 | + bagging_fraction=0.7, |
| 95 | + n=2, |
| 96 | + ) |
| 97 | + mach = machine(ensemble, X, y) |
| 98 | + fit!(mach, verbosity=0) |
| 99 | + io = IOBuffer() |
| 100 | + MLJBase.save(io, mach) |
| 101 | + seekstart(io) |
| 102 | + mach2 = machine(io) |
| 103 | + close(io) |
| 104 | + @test MLJBase.predict(mach2, (; x = rand(2))) ≈ fill(42.0, 2) |
62 | 105 | end |
63 | 106 |
|
64 | 107 | end |
|
0 commit comments