@@ -25,3 +25,44 @@ mutable struct APIx1 <: Static end
25
25
# update fallback = fit
26
26
@test update (m0, 1 , 5 , nothing , randn (2 ), 5 ) == (5 , nothing , nothing )
27
27
end
28
+
29
+ struct DummyUnivariateFinite end
30
+
31
+ mutable struct UnivariateFiniteFitter <: Probabilistic end
32
+
33
+ @testset " models fitting a distribution to data" begin
34
+
35
+ function MLJModelInterface. fit (model:: UnivariateFiniteFitter ,
36
+ verbosity:: Int , X, y)
37
+
38
+ fitresult = DummyUnivariateFinite ()
39
+ report = nothing
40
+ cache = nothing
41
+
42
+ verbosity > 0 && @info " Fitted a $fitresult "
43
+
44
+ return fitresult, cache, report
45
+ end
46
+
47
+ MLJModelInterface. predict (model:: UnivariateFiniteFitter ,
48
+ fitresult,
49
+ X) = fill (fitresult, length (X))
50
+
51
+ MLJModelInterface. input_scitype (:: Type{<:UnivariateFiniteFitter} ) =
52
+ AbstractVector{Nothing}
53
+ MLJModelInterface. target_scitype (:: Type{<:UnivariateFiniteFitter} ) =
54
+ AbstractVector{<: Finite }
55
+
56
+ y = categorical (collect (" aabbccaa" ))
57
+ X = fill (nothing , length (y))
58
+ model = UnivariateFiniteFitter ()
59
+ fitresult, cache, report = MLJModelInterface. fit (model, 1 , X, y)
60
+
61
+ @test cache == nothing
62
+ @test report == nothing
63
+
64
+ ytest = y[1 : 3 ]
65
+ yhat = predict (model, fitresult, fill (nothing , 3 ))
66
+ @test yhat == fill (DummyUnivariateFinite (), 3 )
67
+
68
+ end
0 commit comments