Skip to content

Commit 0a3afe8

Browse files
committed
synched with alpha commit of StatTraits
stattraits -> statisticaltraits update project ooops
1 parent 62ba51f commit 0a3afe8

File tree

4 files changed

+14
-53
lines changed

4 files changed

+14
-53
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ version = "0.3.8"
66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
88
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
9+
StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
910

1011
[compat]
1112
ScientificTypes = "^1"
12-
julia = "1"
13+
StatisticalTraits = "^0.1"
14+
julia = "^1"
1315

1416
[extras]
1517
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/MLJModelInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
module MLJModelInterface
22

33
# ------------------------------------------------------------------------
4-
# Dependencies (ScientificTypes itself does not have dependencies)
4+
# Dependencies (ScientificTypes and StatisticalTraits have none)
55
using ScientificTypes
6+
using StatisticalTraits
67
using Random
78

89
# ------------------------------------------------------------------------

src/model_traits.jl

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,51 +10,19 @@ const MODEL_TRAITS = [
1010
:prediction_type, :implemented_methods, :hyperparameters,
1111
:hyperparameter_types, :hyperparameter_ranges]
1212

13-
for trait in MODEL_TRAITS
14-
ex = quote
15-
$trait(x) = $trait(typeof(x))
16-
end
17-
MLJModelInterface.eval(ex)
18-
end
13+
StatTraits.docstring(M::Type{<:MLJType}) = name(M)
14+
StatTraits.docstring(M::Type{<:Model}) =
15+
"$(name(M)) from $(package_name(M)).jl.\n" *
16+
"[Documentation]($(package_url(M)))."
1917

20-
# fallback trait declarations:
21-
input_scitype(::Type) = Unknown
22-
output_scitype(::Type) = Unknown
23-
target_scitype(::Type) = Unknown # used for measures too
24-
is_pure_julia(::Type) = false
25-
package_name(::Type) = "unknown"
26-
package_license(::Type) = "unknown"
27-
load_path(::Type) = "unknown"
28-
package_uuid(::Type) = "unknown"
29-
package_url(::Type) = "unknown"
30-
is_wrapper(::Type) = false
31-
supports_online(::Type) = false
32-
supports_weights(::Type) = false # used for measures too
33-
hyperparameter_ranges(T::Type) = Tuple(fill(nothing, length(fieldnames(T))))
34-
docstring(M::Type) = string(M)
35-
docstring(M::Type{<:MLJType}) = name(M)
36-
docstring(M::Type{<:Model}) = "$(name(M)) from $(package_name(M)).jl.\n" *
37-
"[Documentation]($(package_url(M)))."
38-
# "derived" traits:
39-
function _coretype(M)
40-
if isdefined(M, :name)
41-
return M.name.name
42-
else
43-
return _coretype(M.body)
44-
end
45-
end
46-
name(M::Type) = string(_coretype(M))
47-
is_supervised(::Type) = false
48-
is_supervised(::Type{<:Supervised}) = true
49-
prediction_type(::Type) = :unknown # used for measures too
50-
prediction_type(::Type{<:Deterministic}) = :deterministic
51-
prediction_type(::Type{<:Probabilistic}) = :probabilistic
52-
prediction_type(::Type{<:Interval}) = :interval
53-
hyperparameters(M::Type) = fieldnames(M)
54-
hyperparameter_types(M::Type) = string.(fieldtypes(M))
18+
StatTraits.is_supervised(::Type{<:Supervised}) = true
19+
StatTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
20+
StatTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
21+
StatTraits.prediction_type(::Type{<:Interval}) = :interval
5522

5623
# implementation is deferred as it requires methodswith which depends upon
5724
# InteractiveUtils which we don't want to bring here as a dependency
5825
# (even if it's stdlib).
5926
implemented_methods(M::Type) = implemented_methods(get_interface_mode(), M)
27+
implemented_methods(model) = implemented_methods(typeof(model))
6028
implemented_methods(::LightInterface, M) = errlight("implemented_methods")

test/model_traits.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ module Fruit
6969
import MLJModelInterface.MLJType
7070

7171
struct Banana <: MLJType end
72-
struct Apple end
7372

7473
end
7574

@@ -78,13 +77,4 @@ import .Fruit
7877
@testset "extras" begin
7978
@test docstring(Float64) == "Float64"
8079
@test docstring(Fruit.Banana) == "Banana"
81-
@test name(Float64) == "Float64"
82-
83-
df = DataFrame(a=randn(2), b=randn(2))
84-
@test M.name(typeof(df)) == "DataFrame"
85-
@test M.name(df) == "DataFrame"
86-
@test M.name(Fruit.Banana) == "Banana"
87-
@test M.name(Fruit.Banana()) == "Banana"
88-
@test M.name(Fruit.Apple) == "Apple"
89-
@test M.name(Fruit.Apple()) == "Apple"
9080
end

0 commit comments

Comments
 (0)