diff --git a/src/model_search.jl b/src/model_search.jl index b85e9e9..19d6cbc 100644 --- a/src/model_search.jl +++ b/src/model_search.jl @@ -389,6 +389,15 @@ function localmodeltypes(modl; toplevel=false, wrappers=false) end end +# helper +function simple_repr(T) + # get rid of type parameters: + output = split(repr(T), "{") |> first + # get rid of qualifiers: + output = split(output, ".") |> last + return output +end + """ localmodels(; modl=Main, wrappers=false) localmodels(filters...; modl=Main, wrappers=false) @@ -410,11 +419,12 @@ See also [`models`](@ref), [`load_path`](@ref). """ function localmodels(args...; modl=Main, kwargs...) - modeltypes = localmodeltypes(modl; kwargs...) + modeltypes = filter(M-> !isabstracttype(M), localmodeltypes(modl; kwargs...)) handles = map(modeltypes) do M - Handle(MMI.name(M), MMI.package_name(M)) + name = is_wrapper(M) ? simple_repr(MMI.constructor(M)) : MMI.name(M) + Handle(name, MMI.package_name(M)) end - return filter(models(args...)) do model + ret = filter(models(args...; kwargs...)) do model Handle(model.name, model.package_name) in handles end end diff --git a/test/model_search.jl b/test/model_search.jl index 24eeefe..7b4b61f 100644 --- a/test/model_search.jl +++ b/test/model_search.jl @@ -151,5 +151,16 @@ end @test pca ∉ models(r"PCA′") end +@testset "https://github.com/JuliaAI/MLJModels.jl/issues/594" begin + ms = map(m->m.name, models()) + ms_plus = map(m->m.name, models(; wrappers=true)) + localms = map(m->m.name, localmodels()) + localms_plus = map(m->m.name, localmodels(; wrappers=true)) + @test !("Pipeline" in ms) + @test "Pipeline" in ms_plus + @test !("Pipeline" in localms) + @test "Pipeline" in localms_plus +end + end true