Skip to content

Commit 11240ae

Browse files
authored
Merge pull request #601 from JuliaAI/dev
For a 0.18.2 release - Take 2
2 parents 7d13f34 + 7384cde commit 11240ae

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

src/model_search.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,15 @@ function localmodeltypes(modl; toplevel=false, wrappers=false)
389389
end
390390
end
391391

392+
# helper
393+
function simple_repr(T)
394+
# get rid of type parameters:
395+
output = split(repr(T), "{") |> first
396+
# get rid of qualifiers:
397+
output = split(output, ".") |> last
398+
return output
399+
end
400+
392401
"""
393402
localmodels(; modl=Main, wrappers=false)
394403
localmodels(filters...; modl=Main, wrappers=false)
@@ -410,11 +419,12 @@ See also [`models`](@ref), [`load_path`](@ref).
410419
411420
"""
412421
function localmodels(args...; modl=Main, kwargs...)
413-
modeltypes = localmodeltypes(modl; kwargs...)
422+
modeltypes = filter(M-> !isabstracttype(M), localmodeltypes(modl; kwargs...))
414423
handles = map(modeltypes) do M
415-
Handle(MMI.name(M), MMI.package_name(M))
424+
name = is_wrapper(M) ? simple_repr(MMI.constructor(M)) : MMI.name(M)
425+
Handle(name, MMI.package_name(M))
416426
end
417-
return filter(models(args...)) do model
427+
ret = filter(models(args...; kwargs...)) do model
418428
Handle(model.name, model.package_name) in handles
419429
end
420430
end

test/model_search.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,5 +151,16 @@ end
151151
@test pca models(r"PCA′")
152152
end
153153

154+
@testset "https://github.com/JuliaAI/MLJModels.jl/issues/594" begin
155+
ms = map(m->m.name, models())
156+
ms_plus = map(m->m.name, models(; wrappers=true))
157+
localms = map(m->m.name, localmodels())
158+
localms_plus = map(m->m.name, localmodels(; wrappers=true))
159+
@test !("Pipeline" in ms)
160+
@test "Pipeline" in ms_plus
161+
@test !("Pipeline" in localms)
162+
@test "Pipeline" in localms_plus
163+
end
164+
154165
end
155166
true

0 commit comments

Comments
 (0)