@@ -41,10 +41,8 @@ for M in ABSTRACT_MODEL_SUBTYPES
41
41
@eval (StatTraits. abstract_type (:: Type{<:$M} ) = $ M)
42
42
end
43
43
44
- StatTraits. fit_data_scitype (M:: Type{<:Unsupervised} ) =
45
- Tuple{input_scitype (M)}
46
- StatTraits. fit_data_scitype (:: Type{<:Static} ) = Tuple{}
47
- function StatTraits. fit_data_scitype (M:: Type{<:Supervised} )
44
+ # helper to determine the scitype of supervised models
45
+ function supervised_fit_data_scitype (M)
48
46
I = input_scitype (M)
49
47
T = target_scitype (M)
50
48
ret = Tuple{I,T}
@@ -57,21 +55,21 @@ function StatTraits.fit_data_scitype(M::Type{<:Supervised})
57
55
end
58
56
return ret
59
57
end
60
- StatTraits. fit_data_scitype (M:: Type{<:UnsupervisedAnnotator} ) =
58
+
59
+ StatTraits. fit_data_scitype (M:: Type{<:Unsupervised} ) =
61
60
Tuple{input_scitype (M)}
61
+ StatTraits. fit_data_scitype (:: Type{<:Static} ) = Tuple{}
62
+ StatTraits. fit_data_scitype (M:: Type{<:Supervised} ) =
63
+ supervised_fit_data_scitype (M)
64
+
65
+ # In special case of `UnsupervisedAnnotator`, we allow the target
66
+ # as an optional argument to `fit` (that is ignored) so that the
67
+ # `machine` constructor will accept it as a valid argument, which
68
+ # then enables *evaluation* of the detector with labeled data:
69
+ StatTraits. fit_data_scitype (M:: Type{<:UnsupervisedAnnotator} ) =
70
+ Union{Tuple{input_scitype (M)}, supervised_fit_data_scitype (M)}
62
71
StatTraits. fit_data_scitype (M:: Type{<:SupervisedAnnotator} ) =
63
- Tuple{input_scitype (M),target_scitype (M)}
64
-
65
- # In special case of `UnsupervisedProbabilisticDetector`, and
66
- # `UnsupervsedDeterministicDetector` we allow the target as an
67
- # optional argument to `fit` (that is ignored) so that the `machine`
68
- # constructor will accept it as a valid argument, which then enables
69
- # *evaluation* of the detector with labeled data:
70
- StatTraits. fit_data_scitype (M:: Type {<: Union {
71
- ProbabilisticUnsupervisedDetector,
72
- DeterministicUnsupervisedDetector}}) =
73
- Union{Tuple{input_scitype (M)},
74
- Tuple{input_scitype (M),target_scitype (M)}}
72
+ supervised_fit_data_scitype (M)
75
73
76
74
StatTraits. transform_scitype (M:: Type{<:Unsupervised} ) =
77
75
output_scitype (M)
@@ -82,7 +80,6 @@ StatTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
82
80
StatTraits. predict_scitype (M:: Type {<: Union {
83
81
Deterministic,DeterministicDetector}}) = target_scitype (M)
84
82
85
-
86
83
# # FALLBACKS FOR `predict_scitype` FOR `Probabilistic` and
87
84
# # `ProbabilisticDetector` MODELS
88
85
0 commit comments