Skip to content

Commit 1153104

Browse files
committed
add tags trait
1 parent 6f83420 commit 1153104

File tree

4 files changed

+9
-10
lines changed

4 files changed

+9
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ REPL = "<0.0.1, 1"
2222
Random = "<0.0.1, 1"
2323
ScientificTypes = "3"
2424
ScientificTypesBase = "3"
25-
StatisticalTraits = "3.4"
25+
StatisticalTraits = "3.5"
2626
Tables = "1"
2727
Test = "<0.0.1, 1"
2828
julia = "1.6"

src/MLJModelInterface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ const MODEL_TRAITS = [
2222
:docstring,
2323
:name,
2424
:human_name,
25+
:tags,
2526
:is_supervised,
2627
:prediction_type,
2728
:abstract_type,
@@ -64,7 +65,7 @@ using ScientificTypesBase
6465
using StatisticalTraits
6566
using Random
6667
using InteractiveUtils
67-
using REPL # apparently needed to get Base.Docs.doc to work
68+
using REPL # needed to get Base.Docs.doc to work
6869

6970
import StatisticalTraits: info
7071

test/metadata_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ infos = Dict(trait => eval(:(MLJModelInterface.$trait))(FooRegressor) for
142142
@test infos[:hyperparameters] == (:a, :b)
143143
@test infos[:hyperparameter_types] == ("Int64", "Any")
144144
@test infos[:hyperparameter_ranges] == (nothing, nothing)
145-
@test !infos[:supports_training_losses]
145+
@test !infos[:supports_training_losses]
146146
@test !infos[:reports_feature_importances]
147147
end
148148

test/model_traits.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ bar(::P1) = nothing
3333
M.package_name(::Type{<:S1}) = "Sibelius"
3434
M.package_url(::Type{<:S1}) = "www.find_the_eighth.org"
3535
M.human_name(::Type{<:S1}) = "silly model"
36+
M.tags(::Type{<:S1}) = ["regression", "gradient descent"]
3637

3738
M.package_name(::Type{<:U1}) = "Bach"
3839
M.package_url(::Type{<:U1}) = "www.did_he_write_565.com"
@@ -103,6 +104,7 @@ M.input_scitype(::Type{<:SupervisedTransformer}) = Finite
103104

104105
@test name(ms) == "S1"
105106
@test human_name(ms) == "silly model"
107+
@test tags(ms) == ["regression", "gradient descent"]
106108

107109

108110
@test is_supervised(ms)
@@ -117,18 +119,14 @@ M.input_scitype(::Type{<:SupervisedTransformer}) = Finite
117119
@test hyperparameters(md) == (:a,)
118120
@test hyperparameter_types(md) == ("Int64",)
119121

120-
# implemented methods is deferred
121-
setlight()
122-
@test_throws M.InterfaceError implemented_methods(mp)
123-
124-
setfull()
122+
@test implemented_methods(ms) == [:clean!,]
125123

126124
@test Set(implemented_methods(mp)) == Set([:clean!,:bar,:foo])
127125

128-
@test fit_data_scitype(mu) == Tuple{Unknown};;;
126+
@test fit_data_scitype(mu) == Tuple{Unknown}
129127
@test fit_data_scitype(mu) == Tuple{Unknown}
130128
@test fit_data_scitype(supervised_transformer) == Tuple{Finite,Continuous}
131-
129+
132130
end
133131

134132
@testset "`_density` - helper for predict_scitype fallback" begin

0 commit comments

Comments
 (0)