@@ -65,6 +65,51 @@ function model_type(T, mod; throw=false, verbosity=1)
65
65
return model_type, outcome
66
66
end
67
67
68
+ # helpers:
69
+ ismissing_or_isa (x, T) = ismissing (x) || x isa T
70
+ bad_trait (model_type) = " $model_type has a bad trait declaration.\n "
71
+
72
+ const err_is_pure_julia (model_type) = ErrorException (
73
+ bad_trait (model_type)* " `is_pure_julia` must return `true` or `false`. "
74
+ )
75
+ const err_supports_weights (model_type) = ErrorException (
76
+ bad_trait (model_type)* " `supports_weights` must return `true`, `false` or `missing`. "
77
+ )
78
+ const err_supports_class_weights (model_type) = ErrorException (
79
+ bad_trait (model_type)* " `supports__class_weights` must return `true`, `false` or `missing`. "
80
+ )
81
+ const err_is_wrapper (model_type) = ErrorException (
82
+ bad_trait (model_type)* " `is_wrapper` must return `true` or `false`. "
83
+ )
84
+ const err_package_name (model_type) = ErrorException (
85
+ bad_trait (model_type)* " `package_name` must return a `String`. "
86
+ )
87
+ const err_packge_license (model_type) = ErrorException (
88
+ bad_trait (model_type)* " `package_license` must return a `String`. "
89
+ )
90
+ const err_iteration_parameter (model_type) = ErrorException (
91
+ bad_trait (model_type)* " `iteration_parameter` must return a `Symbol` or `nothing`. "
92
+ )
93
+
94
+ function traits (model_type; throw= false , verbosity= 1 )
95
+ message = " [:traits] Apply smoke test to some model traits"
96
+ attempt (finalize (message, verbosity); throw) do
97
+ ismissing_or_isa (MLJBase. is_pure_julia (model_type), Bool) ||
98
+ throw (err_is_pure_julia (model_type))
99
+ ismissing_or_isa (MLJBase. supports_weights (model_type), Bool) ||
100
+ throw (err_supports_ (model_type))
101
+ ismissing_or_isa (MLJBase. supports_class_weights (model_type), Bool) ||
102
+ throw (err_supports_class_weights (model_type))
103
+ MLJBase. package_name (model_type) isa String ||
104
+ throw (err_package_name (model_type))
105
+ MLJBase. package_license (model_type) isa String ||
106
+ throw (err_package_license (model_type))
107
+ MLJBase. iteration_parameter (model_type) isa Union{Nothing,Symbol} ||
108
+ throw (err_iteration_parameter (model_type))
109
+ nothing
110
+ end
111
+ end
112
+
68
113
function model_instance (model_type; throw= false , verbosity= 1 )
69
114
message = " [:model_instance] Instantiating default model "
70
115
attempt (finalize (message, verbosity); throw) do
@@ -95,10 +140,21 @@ function operations(fitted_machine, data...; throw=false, verbosity=1)
95
140
methods = MLJBase. implemented_methods (fitted_machine. model)
96
141
_, test = MLJBase. partition (1 : MLJBase. nrows (first (data)), 0.01 )
97
142
if :predict in methods
98
- predict (fitted_machine, first (data))
143
+ yhat = predict (fitted_machine, first (data))
99
144
model isa Static || predict (fitted_machine, rows= test)
100
145
model isa Static || predict (fitted_machine, rows= :)
101
146
push! (operations, " predict" )
147
+
148
+ # check for double wrapped CategoricalValues in predict output for
149
+ # classifiers:
150
+ if target_scitype (model) <: AbstractVector{<:Finite} &&
151
+ model isa Union{Deterministic,Probabilistic}
152
+ η = model isa Deterministic ? first (yhat) : rand (first (yhat))
153
+ unwrap (η) isa MLJBase. CategoricalArrays. CategoricalValue &&
154
+ error (" Doubly wrapped CategoricalValue encountered. Check use of " *
155
+ " CategoricalArrays methods `levels` and `unique`, which changed in " *
156
+ " version 1.0. " )
157
+ end
102
158
end
103
159
if :transform in methods
104
160
W = if model isa Static
0 commit comments