Skip to content

Commit 3a01bc1

Browse files
committed
add trait, target_in_fit
1 parent f730be8 commit 3a01bc1

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

src/MLJModelInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ const MODEL_TRAITS = [
88
:predict_scitype,
99
:transform_scitype,
1010
:inverse_transform_scitype,
11+
:target_in_fit,
1112
:is_pure_julia,
1213
:package_name,
1314
:package_license,

src/model_traits.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ end
3232
StatTraits.is_supervised(::Type{<:Supervised}) = true
3333
StatTraits.is_supervised(::Type{<:SupervisedAnnotator}) = true
3434

35+
StatTraits.target_in_fit(::Type{<:Supervised}) = true
36+
StatTraits.target_in_fit(::Type{<:Unsupervised}) = false
37+
3538
StatTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
3639
StatTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
3740
StatTraits.prediction_type(::Type{<:Interval}) = :interval
@@ -73,7 +76,15 @@ function supervised_fit_data_scitype(M)
7376
return ret
7477
end
7578

76-
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) = Tuple{input_scitype(M)}
79+
# helper to determine the scitype of unsupervised models
80+
function unsupervised_fit_data_scitype(M)
81+
I = input_scitype(M)
82+
T = target_scitype(M)
83+
target_in_fit(M) && return Tuple{I, T}
84+
return Tuple{I}
85+
end
86+
87+
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) = unsupervised_fit_data_scitype(M)
7788
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
7889
StatTraits.fit_data_scitype(M::Type{<:Supervised}) = supervised_fit_data_scitype(M)
7990

test/model_traits.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ end
2323
@mlj_model mutable struct UA <: UnsupervisedAnnotator
2424
end
2525

26+
@mlj_model mutable struct SupervisedTransformer <: Unsupervised
27+
end
28+
29+
2630
foo(::P1) = 0
2731
bar(::P1) = nothing
2832

@@ -34,6 +38,10 @@ M.package_name(::Type{<:U1}) = "Bach"
3438
M.package_url(::Type{<:U1}) = "www.did_he_write_565.com"
3539
M.human_name(::Type{<:U1}) = "funky model"
3640

41+
M.target_in_fit(::Type{<:SupervisedTransformer}) = true
42+
M.target_scitype(::Type{<:SupervisedTransformer}) = Continuous
43+
M.input_scitype(::Type{<:SupervisedTransformer}) = Finite
44+
3745
@testset "traits" begin
3846
ms = S1()
3947
mu = U1(a=42, b=sin)
@@ -42,6 +50,7 @@ M.human_name(::Type{<:U1}) = "funky model"
4250
mi = I1()
4351
sa = SA()
4452
ua = UA()
53+
supervised_transformer = SupervisedTransformer()
4554

4655
@test input_scitype(ms) == Unknown
4756
@test output_scitype(ms) == Unknown
@@ -115,6 +124,11 @@ M.human_name(::Type{<:U1}) = "funky model"
115124
setfull()
116125

117126
@test Set(implemented_methods(mp)) == Set([:clean!,:bar,:foo])
127+
128+
@test fit_data_scitype(mu) == Tuple{Unknown};;;
129+
@test fit_data_scitype(mu) == Tuple{Unknown}
130+
@test fit_data_scitype(supervised_transformer) == Tuple{Finite,Continuous}
131+
118132
end
119133

120134
@testset "`_density` - helper for predict_scitype fallback" begin

0 commit comments

Comments
 (0)