diff --git a/Project.toml b/Project.toml index c3ea08f..65fea1f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "MLJModelInterface" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" authors = ["Thibaut Lienart and Anthony Blaom"] -version = "1.11.1" +version = "1.12.0" [deps] +InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" @@ -21,7 +22,7 @@ REPL = "<0.0.1, 1" Random = "<0.0.1, 1" ScientificTypes = "3" ScientificTypesBase = "3" -StatisticalTraits = "3.4" +StatisticalTraits = "3.5" Tables = "1" Test = "<0.0.1, 1" julia = "1.6" diff --git a/src/MLJModelInterface.jl b/src/MLJModelInterface.jl index f4aa890..48998fb 100644 --- a/src/MLJModelInterface.jl +++ b/src/MLJModelInterface.jl @@ -22,6 +22,7 @@ const MODEL_TRAITS = [ :docstring, :name, :human_name, + :tags, :is_supervised, :prediction_type, :abstract_type, @@ -63,7 +64,8 @@ const ABSTRACT_MODEL_SUBTYPES = [ using ScientificTypesBase using StatisticalTraits using Random -using REPL # apparently needed to get Base.Docs.doc to work +using InteractiveUtils +using REPL # needed to get Base.Docs.doc to work import StatisticalTraits: info diff --git a/src/model_traits.jl b/src/model_traits.jl index 0287081..b2db5f8 100644 --- a/src/model_traits.jl +++ b/src/model_traits.jl @@ -50,11 +50,10 @@ function StatTraits.target_scitype(::Type{<:DeterministicDetector}) return AbstractVector{<:Union{Missing, OrderedFactor{2}}} end -# implementation is deferred as it requires methodswith which depends upon -# InteractiveUtils which we don't want to bring here as a dependency -# (even if it's stdlib). -implemented_methods(M::Type) = implemented_methods(get_interface_mode(), M) +implemented_methods(M::Type) = getfield.(methodswith(M), :name) |> unique implemented_methods(model) = implemented_methods(typeof(model)) + +# can be removed in MLJModelInterface 2.0: implemented_methods(::LightInterface, M) = errlight("implemented_methods") for M in ABSTRACT_MODEL_SUBTYPES diff --git a/test/metadata_utils.jl b/test/metadata_utils.jl index 471ef65..668a5ca 100644 --- a/test/metadata_utils.jl +++ b/test/metadata_utils.jl @@ -142,7 +142,7 @@ infos = Dict(trait => eval(:(MLJModelInterface.$trait))(FooRegressor) for @test infos[:hyperparameters] == (:a, :b) @test infos[:hyperparameter_types] == ("Int64", "Any") @test infos[:hyperparameter_ranges] == (nothing, nothing) - @test !infos[:supports_training_losses] + @test !infos[:supports_training_losses] @test !infos[:reports_feature_importances] end diff --git a/test/model_traits.jl b/test/model_traits.jl index d11227f..31ffd23 100644 --- a/test/model_traits.jl +++ b/test/model_traits.jl @@ -33,6 +33,7 @@ bar(::P1) = nothing M.package_name(::Type{<:S1}) = "Sibelius" M.package_url(::Type{<:S1}) = "www.find_the_eighth.org" M.human_name(::Type{<:S1}) = "silly model" +M.tags(::Type{<:S1}) = ["regression", "gradient descent"] M.package_name(::Type{<:U1}) = "Bach" M.package_url(::Type{<:U1}) = "www.did_he_write_565.com" @@ -103,6 +104,7 @@ M.input_scitype(::Type{<:SupervisedTransformer}) = Finite @test name(ms) == "S1" @test human_name(ms) == "silly model" + @test tags(ms) == ["regression", "gradient descent"] @test is_supervised(ms) @@ -117,18 +119,14 @@ M.input_scitype(::Type{<:SupervisedTransformer}) = Finite @test hyperparameters(md) == (:a,) @test hyperparameter_types(md) == ("Int64",) - # implemented methods is deferred - setlight() - @test_throws M.InterfaceError implemented_methods(mp) - - setfull() + @test implemented_methods(ms) == [:clean!,] @test Set(implemented_methods(mp)) == Set([:clean!,:bar,:foo]) - @test fit_data_scitype(mu) == Tuple{Unknown};;; + @test fit_data_scitype(mu) == Tuple{Unknown} @test fit_data_scitype(mu) == Tuple{Unknown} @test fit_data_scitype(supervised_transformer) == Tuple{Finite,Continuous} - + end @testset "`_density` - helper for predict_scitype fallback" begin