Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "1.11.1"
version = "1.12.0"

[deps]
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
Expand All @@ -21,7 +22,7 @@ REPL = "<0.0.1, 1"
Random = "<0.0.1, 1"
ScientificTypes = "3"
ScientificTypesBase = "3"
StatisticalTraits = "3.4"
StatisticalTraits = "3.5"
Tables = "1"
Test = "<0.0.1, 1"
julia = "1.6"
Expand Down
4 changes: 3 additions & 1 deletion src/MLJModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const MODEL_TRAITS = [
:docstring,
:name,
:human_name,
:tags,
:is_supervised,
:prediction_type,
:abstract_type,
Expand Down Expand Up @@ -63,7 +64,8 @@ const ABSTRACT_MODEL_SUBTYPES = [
using ScientificTypesBase
using StatisticalTraits
using Random
using REPL # apparently needed to get Base.Docs.doc to work
using InteractiveUtils
using REPL # needed to get Base.Docs.doc to work

import StatisticalTraits: info

Expand Down
7 changes: 3 additions & 4 deletions src/model_traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ function StatTraits.target_scitype(::Type{<:DeterministicDetector})
return AbstractVector{<:Union{Missing, OrderedFactor{2}}}
end

# implementation is deferred as it requires methodswith which depends upon
# InteractiveUtils which we don't want to bring here as a dependency
# (even if it's stdlib).
implemented_methods(M::Type) = implemented_methods(get_interface_mode(), M)
implemented_methods(M::Type) = getfield.(methodswith(M), :name) |> unique
implemented_methods(model) = implemented_methods(typeof(model))

# can be removed in MLJModelInterface 2.0:
implemented_methods(::LightInterface, M) = errlight("implemented_methods")

for M in ABSTRACT_MODEL_SUBTYPES
Expand Down
2 changes: 1 addition & 1 deletion test/metadata_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ infos = Dict(trait => eval(:(MLJModelInterface.$trait))(FooRegressor) for
@test infos[:hyperparameters] == (:a, :b)
@test infos[:hyperparameter_types] == ("Int64", "Any")
@test infos[:hyperparameter_ranges] == (nothing, nothing)
@test !infos[:supports_training_losses]
@test !infos[:supports_training_losses]
@test !infos[:reports_feature_importances]
end

Expand Down
12 changes: 5 additions & 7 deletions test/model_traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ bar(::P1) = nothing
M.package_name(::Type{<:S1}) = "Sibelius"
M.package_url(::Type{<:S1}) = "www.find_the_eighth.org"
M.human_name(::Type{<:S1}) = "silly model"
M.tags(::Type{<:S1}) = ["regression", "gradient descent"]

M.package_name(::Type{<:U1}) = "Bach"
M.package_url(::Type{<:U1}) = "www.did_he_write_565.com"
Expand Down Expand Up @@ -103,6 +104,7 @@ M.input_scitype(::Type{<:SupervisedTransformer}) = Finite

@test name(ms) == "S1"
@test human_name(ms) == "silly model"
@test tags(ms) == ["regression", "gradient descent"]


@test is_supervised(ms)
Expand All @@ -117,18 +119,14 @@ M.input_scitype(::Type{<:SupervisedTransformer}) = Finite
@test hyperparameters(md) == (:a,)
@test hyperparameter_types(md) == ("Int64",)

# implemented methods is deferred
setlight()
@test_throws M.InterfaceError implemented_methods(mp)

setfull()
@test implemented_methods(ms) == [:clean!,]

@test Set(implemented_methods(mp)) == Set([:clean!,:bar,:foo])

@test fit_data_scitype(mu) == Tuple{Unknown};;;
@test fit_data_scitype(mu) == Tuple{Unknown}
@test fit_data_scitype(mu) == Tuple{Unknown}
@test fit_data_scitype(supervised_transformer) == Tuple{Finite,Continuous}

end

@testset "`_density` - helper for predict_scitype fallback" begin
Expand Down
Loading