Skip to content

Commit f740946

Browse files
authored
Merge pull request #121 from davnn/patch-2
Supervised fit should work for all unsupervised annotators
2 parents d58a8a0 + 62cd8ae commit f740946

File tree

2 files changed

+20
-24
lines changed

2 files changed

+20
-24
lines changed

src/model_api.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@ fit(::Static, ::Integer, data...) = (nothing, nothing, nothing)
1515
# fallbacks for supervised models that don't support sample weights:
1616
fit(m::Supervised, verbosity, X, y, w) = fit(m, verbosity, X, y)
1717

18-
# fallback for unsupervised detectors when no "evaluation" labels appear:
19-
fit(m::Union{ProbabilisticUnsupervisedDetector,
20-
DeterministicUnsupervisedDetector},
21-
verbosity,
22-
X,
23-
y) = fit(m, verbosity, X)
18+
# fallback for unsupervised annotators when labels or weights appear:
19+
# this is useful for evaluation and mixed composite models that combine
20+
# both supervised and unsupervised annotators
21+
fit(m::UnsupervisedAnnotator, verbosity, X, y) = fit(m, verbosity, X)
22+
fit(m::UnsupervisedAnnotator, verbosity, X, y, w) = fit(m, verbosity, X)
2423

2524
"""
2625
MLJModelInterface.update(model, verbosity, fitresult, cache, data...)

src/model_traits.jl

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,8 @@ for M in ABSTRACT_MODEL_SUBTYPES
4141
@eval(StatTraits.abstract_type(::Type{<:$M}) = $M)
4242
end
4343

44-
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
45-
Tuple{input_scitype(M)}
46-
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
47-
function StatTraits.fit_data_scitype(M::Type{<:Supervised})
44+
# helper to determine the scitype of supervised models
45+
function supervised_fit_data_scitype(M)
4846
I = input_scitype(M)
4947
T = target_scitype(M)
5048
ret = Tuple{I,T}
@@ -57,21 +55,21 @@ function StatTraits.fit_data_scitype(M::Type{<:Supervised})
5755
end
5856
return ret
5957
end
60-
StatTraits.fit_data_scitype(M::Type{<:UnsupervisedAnnotator}) =
58+
59+
StatTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
6160
Tuple{input_scitype(M)}
61+
StatTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
62+
StatTraits.fit_data_scitype(M::Type{<:Supervised}) =
63+
supervised_fit_data_scitype(M)
64+
65+
# In special case of `UnsupervisedAnnotator`, we allow the target
66+
# as an optional argument to `fit` (that is ignored) so that the
67+
# `machine` constructor will accept it as a valid argument, which
68+
# then enables *evaluation* of the detector with labeled data:
69+
StatTraits.fit_data_scitype(M::Type{<:UnsupervisedAnnotator}) =
70+
Union{Tuple{input_scitype(M)}, supervised_fit_data_scitype(M)}
6271
StatTraits.fit_data_scitype(M::Type{<:SupervisedAnnotator}) =
63-
Tuple{input_scitype(M),target_scitype(M)}
64-
65-
# In special case of `UnsupervisedProbabilisticDetector`, and
66-
# `UnsupervsedDeterministicDetector` we allow the target as an
67-
# optional argument to `fit` (that is ignored) so that the `machine`
68-
# constructor will accept it as a valid argument, which then enables
69-
# *evaluation* of the detector with labeled data:
70-
StatTraits.fit_data_scitype(M::Type{<:Union{
71-
ProbabilisticUnsupervisedDetector,
72-
DeterministicUnsupervisedDetector}}) =
73-
Union{Tuple{input_scitype(M)},
74-
Tuple{input_scitype(M),target_scitype(M)}}
72+
supervised_fit_data_scitype(M)
7573

7674
StatTraits.transform_scitype(M::Type{<:Unsupervised}) =
7775
output_scitype(M)
@@ -82,7 +80,6 @@ StatTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
8280
StatTraits.predict_scitype(M::Type{<:Union{
8381
Deterministic,DeterministicDetector}}) = target_scitype(M)
8482

85-
8683
## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` and
8784
## `ProbabilisticDetector` MODELS
8885

0 commit comments

Comments
 (0)