Skip to content

Commit 82ceb84

Browse files
authored
Merge pull request #118 from davnn/patch-1
Supervised annotators should correctly return `is_supervised`
2 parents eaa93b0 + 2c389e6 commit 82ceb84

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/model_traits.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ StatTraits.docstring(M::Type{<:Model}) =
1414
"$(name(M)) from $(package_name(M)).jl.\n" *
1515
"[Documentation]($(package_url(M)))."
1616

17-
StatTraits.is_supervised(::Type{<:Supervised}) = true
17+
StatTraits.is_supervised(::Type{<:Supervised}) = true
18+
StatTraits.is_supervised(::Type{<:SupervisedAnnotator}) = true
1819

1920
StatTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
2021
StatTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic

test/model_traits.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ end
1515
@mlj_model mutable struct I1 <: Interval
1616
end
1717

18+
@mlj_model mutable struct SA <: SupervisedAnnotator
19+
end
20+
21+
@mlj_model mutable struct UA <: UnsupervisedAnnotator
22+
end
23+
1824
foo(::P1) = 0
1925
bar(::P1) = nothing
2026

@@ -24,6 +30,8 @@ bar(::P1) = nothing
2430
md = D1()
2531
mp = P1()
2632
mi = I1()
33+
sa = SA()
34+
ua = UA()
2735

2836
@test input_scitype(ms) == Unknown
2937
@test output_scitype(ms) == Unknown
@@ -47,7 +55,9 @@ bar(::P1) = nothing
4755
@test name(ms) == "S1"
4856

4957
@test is_supervised(ms)
58+
@test is_supervised(sa)
5059
@test !is_supervised(mu)
60+
@test !is_supervised(ua)
5161
@test prediction_type(ms) == :unknown
5262
@test prediction_type(md) == :deterministic
5363
@test prediction_type(mp) == :probabilistic

0 commit comments

Comments
 (0)