File tree Expand file tree Collapse file tree 4 files changed +29
-3
lines changed
Expand file tree Collapse file tree 4 files changed +29
-3
lines changed Original file line number Diff line number Diff line change 11name = " MLJModelInterface"
22uuid = " e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33authors = [" Thibaut Lienart and Anthony Blaom" ]
4- version = " 1.10 .0"
4+ version = " 1.11 .0"
55
66[deps ]
77Random = " 9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -19,7 +19,7 @@ OrderedCollections = "1"
1919Random = " <0.0.1, 1"
2020ScientificTypes = " 3"
2121ScientificTypesBase = " 3"
22- StatisticalTraits = " 3.3 "
22+ StatisticalTraits = " 3.4 "
2323Tables = " 1"
2424Test = " <0.0.1, 1"
2525julia = " 1.6"
Original file line number Diff line number Diff line change @@ -8,6 +8,7 @@ const MODEL_TRAITS = [
88 :predict_scitype ,
99 :transform_scitype ,
1010 :inverse_transform_scitype ,
11+ :target_in_fit ,
1112 :is_pure_julia ,
1213 :package_name ,
1314 :package_license ,
Original file line number Diff line number Diff line change 3232StatTraits. is_supervised (:: Type{<:Supervised} ) = true
3333StatTraits. is_supervised (:: Type{<:SupervisedAnnotator} ) = true
3434
35+ StatTraits. target_in_fit (:: Type{<:Supervised} ) = true
36+ StatTraits. target_in_fit (:: Type{<:Unsupervised} ) = false
37+
3538StatTraits. prediction_type (:: Type{<:Deterministic} ) = :deterministic
3639StatTraits. prediction_type (:: Type{<:Probabilistic} ) = :probabilistic
3740StatTraits. prediction_type (:: Type{<:Interval} ) = :interval
@@ -73,7 +76,15 @@ function supervised_fit_data_scitype(M)
7376 return ret
7477end
7578
76- StatTraits. fit_data_scitype (M:: Type{<:Unsupervised} ) = Tuple{input_scitype (M)}
79+ # helper to determine the scitype of unsupervised models
80+ function unsupervised_fit_data_scitype (M)
81+ I = input_scitype (M)
82+ T = target_scitype (M)
83+ target_in_fit (M) && return Tuple{I, T}
84+ return Tuple{I}
85+ end
86+
87+ StatTraits. fit_data_scitype (M:: Type{<:Unsupervised} ) = unsupervised_fit_data_scitype (M)
7788StatTraits. fit_data_scitype (:: Type{<:Static} ) = Tuple{}
7889StatTraits. fit_data_scitype (M:: Type{<:Supervised} ) = supervised_fit_data_scitype (M)
7990
Original file line number Diff line number Diff line change 2323@mlj_model mutable struct UA <: UnsupervisedAnnotator
2424end
2525
26+ @mlj_model mutable struct SupervisedTransformer <: Unsupervised
27+ end
28+
29+
2630foo (:: P1 ) = 0
2731bar (:: P1 ) = nothing
2832
@@ -34,6 +38,10 @@ M.package_name(::Type{<:U1}) = "Bach"
3438M. package_url (:: Type{<:U1} ) = " www.did_he_write_565.com"
3539M. human_name (:: Type{<:U1} ) = " funky model"
3640
41+ M. target_in_fit (:: Type{<:SupervisedTransformer} ) = true
42+ M. target_scitype (:: Type{<:SupervisedTransformer} ) = Continuous
43+ M. input_scitype (:: Type{<:SupervisedTransformer} ) = Finite
44+
3745@testset " traits" begin
3846 ms = S1 ()
3947 mu = U1 (a= 42 , b= sin)
@@ -42,6 +50,7 @@ M.human_name(::Type{<:U1}) = "funky model"
4250 mi = I1 ()
4351 sa = SA ()
4452 ua = UA ()
53+ supervised_transformer = SupervisedTransformer ()
4554
4655 @test input_scitype (ms) == Unknown
4756 @test output_scitype (ms) == Unknown
@@ -115,6 +124,11 @@ M.human_name(::Type{<:U1}) = "funky model"
115124 setfull ()
116125
117126 @test Set (implemented_methods (mp)) == Set ([:clean! ,:bar ,:foo ])
127+
128+ @test fit_data_scitype (mu) == Tuple{Unknown};;;
129+ @test fit_data_scitype (mu) == Tuple{Unknown}
130+ @test fit_data_scitype (supervised_transformer) == Tuple{Finite,Continuous}
131+
118132end
119133
120134@testset " `_density` - helper for predict_scitype fallback" begin
You can’t perform that action at this time.
0 commit comments