Skip to content

Commit 9a44da5

Browse files
authored
Merge pull request #86 from alan-turing-institute/dev
For a 0.3.9 release
2 parents c1cd479 + 524b320 commit 9a44da5

File tree

5 files changed

+42
-77
lines changed

5 files changed

+42
-77
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "0.3.8"
4+
version = "0.3.9"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
88
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
9+
StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
910

1011
[compat]
1112
ScientificTypes = "^1"
12-
julia = "1"
13+
StatisticalTraits = "^0.1"
14+
julia = "^1"
1315

1416
[extras]
1517
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/MLJModelInterface.jl

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,32 @@
11
module MLJModelInterface
22

3+
const MODEL_TRAITS = [
4+
:input_scitype,
5+
:output_scitype,
6+
:target_scitype,
7+
:is_pure_julia,
8+
:package_name,
9+
:package_license,
10+
:load_path,
11+
:package_uuid,
12+
:package_url,
13+
:is_wrapper,
14+
:supports_weights,
15+
:supports_class_weights,
16+
:supports_online,
17+
:docstring,
18+
:name,
19+
:is_supervised,
20+
:prediction_type,
21+
:implemented_methods,
22+
:hyperparameters,
23+
:hyperparameter_types,
24+
:hyperparameter_ranges]
25+
326
# ------------------------------------------------------------------------
4-
# Dependencies (ScientificTypes itself does not have dependencies)
27+
# Dependencies (ScientificTypes and StatisticalTraits have none)
528
using ScientificTypes
29+
using StatisticalTraits
630
using Random
731

832
# ------------------------------------------------------------------------
@@ -28,13 +52,9 @@ export fit, update, update_data, transform, inverse_transform,
2852
predict_joint, evaluate, clean!, reformat
2953

3054
# model traits
31-
export input_scitype, output_scitype, target_scitype,
32-
is_pure_julia, package_name, package_license,
33-
load_path, package_uuid, package_url,
34-
is_wrapper, supports_weights, supports_online,
35-
docstring, name, is_supervised,
36-
prediction_type, implemented_methods, hyperparameters,
37-
hyperparameter_types, hyperparameter_ranges
55+
for trait in MODEL_TRAITS
56+
@eval(export $trait)
57+
end
3858

3959
# data operations
4060
export matrix, int, classes, decoder, table,
@@ -93,7 +113,6 @@ abstract type JointProbabilistic <: Probabilistic end
93113
# ------------------------------------------------------------------------
94114
# includes
95115

96-
include("utils.jl")
97116
include("parameter_inspection.jl")
98117
include("data_utils.jl")
99118
include("metadata_utils.jl")

src/model_traits.jl

Lines changed: 10 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,18 @@
1-
## MODEL TRAITS
1+
## OVERLOADING TRAIT DEFAULTS RELEVANT TO MODELS
22

3-
# model trait names:
4-
const MODEL_TRAITS = [
5-
:input_scitype, :output_scitype, :target_scitype,
6-
:is_pure_julia, :package_name, :package_license,
7-
:load_path, :package_uuid, :package_url,
8-
:is_wrapper, :supports_weights, :supports_online,
9-
:docstring, :name, :is_supervised,
10-
:prediction_type, :implemented_methods, :hyperparameters,
11-
:hyperparameter_types, :hyperparameter_ranges]
3+
StatisticalTraits.docstring(M::Type{<:MLJType}) = name(M)
4+
StatisticalTraits.docstring(M::Type{<:Model}) =
5+
"$(name(M)) from $(package_name(M)).jl.\n" *
6+
"[Documentation]($(package_url(M)))."
127

13-
for trait in MODEL_TRAITS
14-
ex = quote
15-
$trait(x) = $trait(typeof(x))
16-
end
17-
MLJModelInterface.eval(ex)
18-
end
19-
20-
# fallback trait declarations:
21-
input_scitype(::Type) = Unknown
22-
output_scitype(::Type) = Unknown
23-
target_scitype(::Type) = Unknown # used for measures too
24-
is_pure_julia(::Type) = false
25-
package_name(::Type) = "unknown"
26-
package_license(::Type) = "unknown"
27-
load_path(::Type) = "unknown"
28-
package_uuid(::Type) = "unknown"
29-
package_url(::Type) = "unknown"
30-
is_wrapper(::Type) = false
31-
supports_online(::Type) = false
32-
supports_weights(::Type) = false # used for measures too
33-
hyperparameter_ranges(T::Type) = Tuple(fill(nothing, length(fieldnames(T))))
34-
docstring(M::Type) = string(M)
35-
docstring(M::Type{<:MLJType}) = name(M)
36-
docstring(M::Type{<:Model}) = "$(name(M)) from $(package_name(M)).jl.\n" *
37-
"[Documentation]($(package_url(M)))."
38-
# "derived" traits:
39-
function _coretype(M)
40-
if isdefined(M, :name)
41-
return M.name.name
42-
else
43-
return _coretype(M.body)
44-
end
45-
end
46-
name(M::Type) = string(_coretype(M))
47-
is_supervised(::Type) = false
48-
is_supervised(::Type{<:Supervised}) = true
49-
prediction_type(::Type) = :unknown # used for measures too
50-
prediction_type(::Type{<:Deterministic}) = :deterministic
51-
prediction_type(::Type{<:Probabilistic}) = :probabilistic
52-
prediction_type(::Type{<:Interval}) = :interval
53-
hyperparameters(M::Type) = fieldnames(M)
54-
hyperparameter_types(M::Type) = string.(fieldtypes(M))
8+
StatisticalTraits.is_supervised(::Type{<:Supervised}) = true
9+
StatisticalTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
10+
StatisticalTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
11+
StatisticalTraits.prediction_type(::Type{<:Interval}) = :interval
5512

5613
# implementation is deferred as it requires methodswith which depends upon
5714
# InteractiveUtils which we don't want to bring here as a dependency
5815
# (even if it's stdlib).
5916
implemented_methods(M::Type) = implemented_methods(get_interface_mode(), M)
17+
implemented_methods(model) = implemented_methods(typeof(model))
6018
implemented_methods(::LightInterface, M) = errlight("implemented_methods")

src/utils.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

test/model_traits.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ module Fruit
6969
import MLJModelInterface.MLJType
7070

7171
struct Banana <: MLJType end
72-
struct Apple end
7372

7473
end
7574

@@ -78,13 +77,4 @@ import .Fruit
7877
@testset "extras" begin
7978
@test docstring(Float64) == "Float64"
8079
@test docstring(Fruit.Banana) == "Banana"
81-
@test name(Float64) == "Float64"
82-
83-
df = DataFrame(a=randn(2), b=randn(2))
84-
@test M.name(typeof(df)) == "DataFrame"
85-
@test M.name(df) == "DataFrame"
86-
@test M.name(Fruit.Banana) == "Banana"
87-
@test M.name(Fruit.Banana()) == "Banana"
88-
@test M.name(Fruit.Apple) == "Apple"
89-
@test M.name(Fruit.Apple()) == "Apple"
9080
end

0 commit comments

Comments
 (0)