|
1 |
| -## MODEL TRAITS |
| 1 | +## OVERLOADING TRAIT DEFAULTS RELEVANT TO MODELS |
2 | 2 |
|
3 |
| -# model trait names: |
4 |
| -const MODEL_TRAITS = [ |
5 |
| - :input_scitype, :output_scitype, :target_scitype, |
6 |
| - :is_pure_julia, :package_name, :package_license, |
7 |
| - :load_path, :package_uuid, :package_url, |
8 |
| - :is_wrapper, :supports_weights, :supports_online, |
9 |
| - :docstring, :name, :is_supervised, |
10 |
| - :prediction_type, :implemented_methods, :hyperparameters, |
11 |
| - :hyperparameter_types, :hyperparameter_ranges] |
| 3 | +StatisticalTraits.docstring(M::Type{<:MLJType}) = name(M) |
| 4 | +StatisticalTraits.docstring(M::Type{<:Model}) = |
| 5 | + "$(name(M)) from $(package_name(M)).jl.\n" * |
| 6 | + "[Documentation]($(package_url(M)))." |
12 | 7 |
|
13 |
| -for trait in MODEL_TRAITS |
14 |
| - ex = quote |
15 |
| - $trait(x) = $trait(typeof(x)) |
16 |
| - end |
17 |
| - MLJModelInterface.eval(ex) |
18 |
| -end |
19 |
| - |
20 |
| -# fallback trait declarations: |
21 |
| -input_scitype(::Type) = Unknown |
22 |
| -output_scitype(::Type) = Unknown |
23 |
| -target_scitype(::Type) = Unknown # used for measures too |
24 |
| -is_pure_julia(::Type) = false |
25 |
| -package_name(::Type) = "unknown" |
26 |
| -package_license(::Type) = "unknown" |
27 |
| -load_path(::Type) = "unknown" |
28 |
| -package_uuid(::Type) = "unknown" |
29 |
| -package_url(::Type) = "unknown" |
30 |
| -is_wrapper(::Type) = false |
31 |
| -supports_online(::Type) = false |
32 |
| -supports_weights(::Type) = false # used for measures too |
33 |
| -hyperparameter_ranges(T::Type) = Tuple(fill(nothing, length(fieldnames(T)))) |
34 |
| -docstring(M::Type) = string(M) |
35 |
| -docstring(M::Type{<:MLJType}) = name(M) |
36 |
| -docstring(M::Type{<:Model}) = "$(name(M)) from $(package_name(M)).jl.\n" * |
37 |
| - "[Documentation]($(package_url(M)))." |
38 |
| -# "derived" traits: |
39 |
| -function _coretype(M) |
40 |
| - if isdefined(M, :name) |
41 |
| - return M.name.name |
42 |
| - else |
43 |
| - return _coretype(M.body) |
44 |
| - end |
45 |
| -end |
46 |
| -name(M::Type) = string(_coretype(M)) |
47 |
| -is_supervised(::Type) = false |
48 |
| -is_supervised(::Type{<:Supervised}) = true |
49 |
| -prediction_type(::Type) = :unknown # used for measures too |
50 |
| -prediction_type(::Type{<:Deterministic}) = :deterministic |
51 |
| -prediction_type(::Type{<:Probabilistic}) = :probabilistic |
52 |
| -prediction_type(::Type{<:Interval}) = :interval |
53 |
| -hyperparameters(M::Type) = fieldnames(M) |
54 |
| -hyperparameter_types(M::Type) = string.(fieldtypes(M)) |
| 8 | +StatisticalTraits.is_supervised(::Type{<:Supervised}) = true |
| 9 | +StatisticalTraits.prediction_type(::Type{<:Deterministic}) = :deterministic |
| 10 | +StatisticalTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic |
| 11 | +StatisticalTraits.prediction_type(::Type{<:Interval}) = :interval |
55 | 12 |
|
56 | 13 | # implementation is deferred as it requires methodswith which depends upon
|
57 | 14 | # InteractiveUtils which we don't want to bring here as a dependency
|
58 | 15 | # (even if it's stdlib).
|
59 | 16 | implemented_methods(M::Type) = implemented_methods(get_interface_mode(), M)
|
| 17 | +implemented_methods(model) = implemented_methods(typeof(model)) |
60 | 18 | implemented_methods(::LightInterface, M) = errlight("implemented_methods")
|
0 commit comments