Skip to content

Commit ee138b4

Browse files
committed
add test of model learning a distribution
1 parent 3aa1c9c commit ee138b4

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

test/model_api.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,44 @@ mutable struct APIx1 <: Static end
2525
#update fallback = fit
2626
@test update(m0, 1, 5, nothing, randn(2), 5) == (5, nothing, nothing)
2727
end
28+
29+
struct DummyUnivariateFinite end
30+
31+
mutable struct UnivariateFiniteFitter <: Probabilistic end
32+
33+
@testset "models fitting a distribution to data" begin
34+
35+
function MLJModelInterface.fit(model::UnivariateFiniteFitter,
36+
verbosity::Int, X, y)
37+
38+
fitresult = DummyUnivariateFinite()
39+
report = nothing
40+
cache = nothing
41+
42+
verbosity > 0 && @info "Fitted a $fitresult"
43+
44+
return fitresult, cache, report
45+
end
46+
47+
MLJModelInterface.predict(model::UnivariateFiniteFitter,
48+
fitresult,
49+
X) = fill(fitresult, length(X))
50+
51+
MLJModelInterface.input_scitype(::Type{<:UnivariateFiniteFitter}) =
52+
AbstractVector{Nothing}
53+
MLJModelInterface.target_scitype(::Type{<:UnivariateFiniteFitter}) =
54+
AbstractVector{<:Finite}
55+
56+
y =categorical(collect("aabbccaa"))
57+
X = fill(nothing, length(y))
58+
model = UnivariateFiniteFitter()
59+
fitresult, cache, report = MLJModelInterface.fit(model, 1, X, y)
60+
61+
@test cache == nothing
62+
@test report == nothing
63+
64+
ytest = y[1:3]
65+
yhat = predict(model, fitresult, fill(nothing, 3))
66+
@test yhat == fill(DummyUnivariateFinite(), 3)
67+
68+
end

0 commit comments

Comments
 (0)