Skip to content

Commit f66fb59

Browse files
committed
clean up of metadata_model() with some default behaviour changes
1 parent d207f88 commit f66fb59

File tree

4 files changed

+69
-89
lines changed

4 files changed

+69
-89
lines changed

src/metadata_utils.jl

Lines changed: 42 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,3 @@
1-
"""
2-
docstring_ext
3-
4-
Internal function to help generate the docstring for a package. See
5-
[`metadata_model`](@ref).
6-
"""
7-
function docstring_ext(T; descr::String="")
8-
package_name = MLJModelInterface.package_name(T)
9-
package_url = MLJModelInterface.package_url(T)
10-
model_name = MLJModelInterface.name(T)
11-
# the message to return
12-
message = "$descr"
13-
message *= "\n→ based on [$package_name]($package_url)."
14-
message *= "\n→ do `@load $model_name pkg=\"$package_name\"` to " *
15-
"use the model."
16-
message *= "\n→ do `?$model_name` for documentation."
17-
end
18-
191
"""
202
metadata_pkg(T; args...)
213
@@ -72,19 +54,31 @@ function metadata_pkg(
7254
parentmodule(T).eval(ex)
7355
end
7456

57+
# Extend `program` (an expression) to include trait definition for
58+
# specified `trait` and type `T`.
59+
function _extend!(program::Expr, trait::Symbol, value, T)
60+
if value !== nothing
61+
push!(program.args, quote
62+
MLJModelInterface.$trait(::Type{<:$T}) = $value
63+
end)
64+
end
65+
end
66+
67+
const WARN_MISSING_LOAD_PATH = "No `load_path` defined. "
68+
69+
7570
"""
7671
metadata_model(`T`; args...)
7772
7873
Helper function to write the metadata for a model `T`.
7974
8075
## Keywords
8176
82-
* `input_scitype=Unknown` : allowed scientific type of the input data
83-
* `target_scitype=Unknown`: allowed sc. type of the target (supervised)
84-
* `output_scitype=Unknown`: allowed sc. type of the transformed data (unsupervised)
85-
* `supports_weights=false` : whether the model supports sample weights
86-
* `docstring=""` : short description of the model
87-
* `load_path=""` : where the model is (usually `PackageName.ModelName`)
77+
* `input_scitype=Unknown`: allowed scientific type of the input data
78+
* `target_scitype=Unknown`: allowed scitype of the target (supervised)
79+
* `output_scitype=Unkonwn`: allowed scitype of the transformed data (unsupervised)
80+
* `supports_weights=false`: whether the model supports sample weights
81+
* `load_path="unknown"`: where the model is (usually `PackageName.ModelName`)
8882
8983
## Example
9084
@@ -93,43 +87,40 @@ metadata_model(KNNRegressor,
9387
input_scitype=MLJModelInterface.Table(MLJModelInterface.Continuous),
9488
target_scitype=AbstractVector{MLJModelInterface.Continuous},
9589
supports_weights=true,
96-
docstring="K-Nearest Neighbors classifier: ...",
9790
load_path="NearestNeighbors.KNNRegressor")
9891
```
9992
"""
10093
function metadata_model(
10194
T;
10295
# aliases:
103-
input=Unknown,
104-
target=Unknown,
105-
output=Unknown,
106-
weights::Bool=false,
107-
descr::String="",
108-
path::String="",
96+
input=nothing,
97+
target=nothing,
98+
output=nothing,
99+
weights::Union{Nothing,Bool}=nothing,
100+
descr::Union{Nothing,String}=nothing,
101+
path::Union{Nothing,String}=nothing,
109102

110103
# preferred names, corresponding to trait names:
111104
input_scitype=input,
112105
target_scitype=target,
113106
output_scitype=output,
114-
supports_weights=weights,
115-
docstring=descr,
116-
load_path=path,
107+
supports_weights::Union{Nothing,Bool}=weights,
108+
docstring::Union{Nothing,String}=descr,
109+
load_path::Union{Nothing,String}=path,
117110
)
118-
if isempty(load_path)
119-
pname = MLJModelInterface.package_name(T)
120-
mname = MLJModelInterface.name(T)
121-
load_path = "MLJModels.$(pname)_.$(mname)"
122-
end
123-
ex = quote
124-
MLJModelInterface.input_scitype(::Type{<:$T}) = $input_scitype
125-
MLJModelInterface.output_scitype(::Type{<:$T}) = $output_scitype
126-
MLJModelInterface.target_scitype(::Type{<:$T}) = $target_scitype
127-
MLJModelInterface.supports_weights(::Type{<:$T}) = $supports_weights
128-
MLJModelInterface.load_path(::Type{<:$T}) = $load_path
129-
130-
function MLJModelInterface.docstring(::Type{<:$T})
131-
return MLJModelInterface.docstring_ext($T; descr=$docstring)
132-
end
133-
end
134-
parentmodule(T).eval(ex)
111+
load_path === nothing && @warn WARN_MISSING_LOAD_PATH
112+
113+
program = quote end
114+
115+
# Note: Naively using metaprogramming to roll up the following
116+
# code does not work. Only change this if you really know what
117+
# you're doing.
118+
_extend!(program, :input_scitype, input_scitype, T)
119+
_extend!(program, :target_scitype, target_scitype, T)
120+
_extend!(program, :output_scitype, output_scitype, T)
121+
_extend!(program, :supports_weights, supports_weights, T)
122+
_extend!(program, :docstring, docstring, T)
123+
_extend!(program, :load_path, load_path, T)
124+
125+
parentmodule(T).eval(program)
135126
end

src/model_traits.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,7 @@ const DeterministicDetector = Union{
1313

1414
const StatTraits = StatisticalTraits
1515

16-
StatTraits.docstring(M::Type{<:MLJType}) = name(M)
17-
18-
function StatTraits.docstring(M::Type{<:Model})
19-
return "$(name(M)) from $(package_name(M)).jl.\n" *
20-
"[Documentation]($(package_url(M)))."
21-
end
16+
StatTraits.docstring(M::Type{<:Model}) = Base.Docs.doc(M) |> string
2217

2318
StatTraits.is_supervised(::Type{<:Supervised}) = true
2419
StatTraits.is_supervised(::Type{<:SupervisedAnnotator}) = true

test/metadata_utils.jl

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
# poor man's info dict for testing
2-
info_dict(MM::Type{<:Model}) =
3-
Dict(trait => eval(:($trait))(MM) for trait in M.MODEL_TRAITS)
4-
1+
"""Cool model"""
52
@mlj_model mutable struct FooRegressor <: Deterministic
63
a::Int = 0::(_ ≥ 0)
74
b
85
end
96

7+
struct BarGoo <: Deterministic end
8+
109
metadata_pkg(FooRegressor,
1110
name="FooRegressorPkg",
1211
uuid="10745b16-79ce-11e8-11f9-7d13ad32a3b2",
@@ -15,30 +14,40 @@ metadata_pkg(FooRegressor,
1514
license="MIT",
1615
is_wrapper=false
1716
)
17+
18+
# this is added in MLJBase but not in MLJModelInterface, to avoid
19+
# InteractiveUtils as dependency:
20+
setfull()
21+
M.implemented_methods(::FI, M::Type{<:MLJType}) =
22+
getfield.(methodswith(M), :name)
23+
24+
@test_logs(
25+
(:warn, MLJModelInterface.WARN_MISSING_LOAD_PATH),
26+
metadata_model(BarGoo)
27+
)
28+
1829
metadata_model(FooRegressor,
19-
input=Table(Continuous),
20-
target=AbstractVector{Continuous},
21-
descr="La di da")
30+
input_scitype=Table(Continuous),
31+
target_scitype=AbstractVector{Continuous},
32+
load_path="goo goo")
2233

23-
@testset "metadata" begin
24-
setfull()
25-
M.implemented_methods(::FI, M::Type{<:MLJType}) =
26-
getfield.(methodswith(M), :name)
27-
infos = info_dict(FooRegressor)
34+
infos = Dict(trait => eval(:(MLJModelInterface.$trait))(FooRegressor) for
35+
trait in M.MODEL_TRAITS)
2836

37+
@testset "metadata" begin
2938
@test infos[:input_scitype] == Table(Continuous)
3039
@test infos[:output_scitype] == Unknown
3140
@test infos[:target_scitype] == AbstractVector{Continuous}
3241
@test infos[:is_pure_julia]
3342
@test infos[:package_name] == "FooRegressorPkg"
3443
@test infos[:package_license] == "MIT"
35-
@test infos[:load_path] == "MLJModels.FooRegressorPkg_.FooRegressor"
44+
@test infos[:load_path] == "goo goo"
3645
@test infos[:package_uuid] == "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3746
@test infos[:package_url] == "http://existentialcomics.com/"
3847
@test !infos[:is_wrapper]
3948
@test !infos[:supports_weights]
4049
@test !infos[:supports_online]
41-
@test startswith(infos[:docstring], "La di da")
50+
@test infos[:docstring] == "Cool model\n"
4251
@test infos[:name] == "FooRegressor"
4352
@test infos[:is_supervised]
4453
@test infos[:prediction_type] == :deterministic

test/model_traits.jl

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ bar(::P1) = nothing
5151

5252
@test hyperparameter_ranges(md) == (nothing,)
5353

54-
@test docstring(ms) == "S1 from unknown.jl.\n[Documentation](unknown)."
54+
@test contains(docstring(ms), "No documentation")
5555
@test name(ms) == "S1"
5656

5757
@test is_supervised(ms)
@@ -69,31 +69,16 @@ bar(::P1) = nothing
6969
# implemented methods is deferred
7070
setlight()
7171
@test_throws M.InterfaceError implemented_methods(mp)
72-
72+
7373
setfull()
74-
74+
7575
function M.implemented_methods(::FI, M::Type{<:MLJType})
7676
return getfield.(methodswith(M), :name)
7777
end
7878

7979
@test Set(implemented_methods(mp)) == Set([:clean!,:bar,:foo])
8080
end
8181

82-
module Fruit
83-
84-
import MLJModelInterface.MLJType
85-
86-
struct Banana <: MLJType end
87-
88-
end
89-
90-
import .Fruit
91-
92-
@testset "extras" begin
93-
@test docstring(Float64) == "Float64"
94-
@test docstring(Fruit.Banana) == "Banana"
95-
end
96-
9782
@testset "`_density` - helper for predict_scitype fallback" begin
9883
for T in [Continuous, Count, Textual]
9984
@test ==(

0 commit comments

Comments
 (0)