@@ -14,12 +14,22 @@ mutable struct APIx1 <: Static end
14
14
@test selectrows (APIx0 (), 2 : 3 , X, y) == ((x1 = [4 , 6 ],), [20.0 , 30.0 ])
15
15
end
16
16
17
+ M. metadata_model (
18
+ APIx0,
19
+ supports_training_losses = true ,
20
+ reports_feature_importances = true ,
21
+ )
22
+
23
+ dummy_losses = [1.0 , 2.0 , 3.0 ]
24
+ M. training_losses (:: APIx0 , report) = report
25
+ M. feature_importances (:: APIx0 , fitresult, report) = [:a => 0 , :b => 0 ]
26
+
17
27
@testset " fit-x" begin
18
28
m0 = APIx0 (f0= 1 )
19
29
m1 = APIx0b (f0= 3 )
20
30
# no weight support: fallback
21
- M. fit (m:: APIx0 , v:: Int , X, y) = (5 , nothing , nothing )
22
- @test fit (m0, 1 , randn (2 ), randn (2 ), 5 ) == (5 , nothing , nothing )
31
+ M. fit (m:: APIx0 , v:: Int , X, y) = (5 , nothing , dummy_losses )
32
+ @test fit (m0, 1 , randn (2 ), randn (2 ), 5 ) == (5 , nothing , dummy_losses )
23
33
# with weight support: use
24
34
M. fit (m:: APIx0b , v:: Int , X, y, w) = (7 , nothing , nothing )
25
35
@test fit (m1, 1 , randn (2 ), randn (2 ), 5 ) == (7 , nothing , nothing )
32
42
@test fit (s1, 1 , 0 ) == (nothing , nothing , nothing )
33
43
34
44
# update fallback = fit
35
- @test update (m0, 1 , 5 , nothing , randn (2 ), 5 ) == (5 , nothing , nothing )
45
+ @test update (m0, 1 , 5 , nothing , randn (2 ), 5 ) == (5 , nothing , dummy_losses )
36
46
37
47
# training losses:
38
48
f, c, r = MLJModelInterface. fit (m0, 1 , rand (2 ), rand (2 ))
39
- @test M. training_losses (m0, r) === nothing
40
-
41
- # intrinsic_importances
49
+ @test M. training_losses (m0, r) == dummy_losses
50
+
51
+ # training losses:
52
+ f, c, r = MLJModelInterface. fit (m0, 1 , rand (2 ), rand (2 ))
53
+ @test M. training_losses (m0, r) == dummy_losses
54
+
55
+ # feature_importances
42
56
f, c, r = MLJModelInterface. fit (m0, 1 , rand (2 ), rand (2 ))
43
- MLJModelInterface. reports_feature_importances (:: Type{APIx0} ) = true
44
- MLJModelInterface. feature_importances (:: APIx0 , fitresult, report) = [:a => 0 , :b => 0 ]
45
57
@test MLJModelInterface. feature_importances (m0, f, r) == [:a => 0 , :b => 0 ]
46
58
end
47
59
@@ -67,7 +79,7 @@ mutable struct UnivariateFiniteFitter <: Probabilistic end
67
79
end
68
80
69
81
MMI. input_scitype (:: Type{<:UnivariateFiniteFitter} ) = Nothing
70
-
82
+
71
83
MMI. target_scitype (:: Type{<:UnivariateFiniteFitter} ) = AbstractVector{<: Finite }
72
84
73
85
y = categorical (collect (" aabbccaa" ))
0 commit comments