Skip to content

Commit 4845e17

Browse files
authored
Merge pull request #141 from JuliaAI/dev
For a 1.4.0 release
2 parents 4e1a382 + 8b11617 commit 4845e17

File tree

7 files changed

+310
-90
lines changed

7 files changed

+310
-90
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "1.3.6"
4+
version = "1.4.0"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -18,9 +18,10 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
1818
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1919
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
2020
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
21+
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2122
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
2223
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2324
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2425

2526
[targets]
26-
test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "ScientificTypes", "Tables", "Test"]
27+
test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "ScientificTypes", "Tables", "Test"]

src/MLJModelInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ const MODEL_TRAITS = [
2020
:supports_online,
2121
:docstring,
2222
:name,
23+
:human_name,
2324
:is_supervised,
2425
:prediction_type,
2526
:abstract_type,

src/metadata_utils.jl

Lines changed: 191 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,3 @@
1-
"""
2-
docstring_ext
3-
4-
Internal function to help generate the docstring for a package. See
5-
[`metadata_model`](@ref).
6-
"""
7-
function docstring_ext(T; descr::String="")
8-
package_name = MLJModelInterface.package_name(T)
9-
package_url = MLJModelInterface.package_url(T)
10-
model_name = MLJModelInterface.name(T)
11-
# the message to return
12-
message = "$descr"
13-
message *= "\n→ based on [$package_name]($package_url)."
14-
message *= "\n→ do `@load $model_name pkg=\"$package_name\"` to " *
15-
"use the model."
16-
message *= "\n→ do `?$model_name` for documentation."
17-
end
18-
191
"""
202
metadata_pkg(T; args...)
213
@@ -72,19 +54,30 @@ function metadata_pkg(
7254
parentmodule(T).eval(ex)
7355
end
7456

57+
# Extend `program` (an expression) to include trait definition for
58+
# specified `trait` and type `T`.
59+
function _extend!(program::Expr, trait::Symbol, value, T)
60+
if value !== nothing
61+
push!(program.args, quote
62+
MLJModelInterface.$trait(::Type{<:$T}) = $value
63+
end)
64+
return nothing
65+
end
66+
end
67+
7568
"""
7669
metadata_model(T; args...)
7770
7871
Helper function to write the metadata for a model `T`.
7972
8073
## Keywords
8174
82-
* `input_scitype=Unknown` : allowed scientific type of the input data
83-
* `target_scitype=Unknown`: allowed sc. type of the target (supervised)
84-
* `output_scitype=Unknown`: allowed sc. type of the transformed data (unsupervised)
85-
* `supports_weights=false` : whether the model supports sample weights
86-
* `docstring=""` : short description of the model
87-
* `load_path=""` : where the model is (usually `PackageName.ModelName`)
75+
* `input_scitype=Unknown`: allowed scientific type of the input data
76+
* `target_scitype=Unknown`: allowed scitype of the target (supervised)
77+
* `output_scitype=Unknown`: allowed scitype of the transformed data (unsupervised)
78+
* `supports_weights=false`: whether the model supports sample weights
79+
* `supports_class_weights=false`: whether the model supports class weights
80+
* `load_path="unknown"`: where the model is (usually `PackageName.ModelName`)
8881
8982
## Example
9083
@@ -93,43 +86,192 @@ metadata_model(KNNRegressor,
9386
input_scitype=MLJModelInterface.Table(MLJModelInterface.Continuous),
9487
target_scitype=AbstractVector{MLJModelInterface.Continuous},
9588
supports_weights=true,
96-
docstring="K-Nearest Neighbors classifier: ...",
9789
load_path="NearestNeighbors.KNNRegressor")
9890
```
9991
"""
10092
function metadata_model(
10193
T;
10294
# aliases:
103-
input=Unknown,
104-
target=Unknown,
105-
output=Unknown,
106-
weights::Bool=false,
107-
descr::String="",
108-
path::String="",
95+
input=nothing,
96+
target=nothing,
97+
output=nothing,
98+
weights::Union{Nothing,Bool}=nothing,
99+
class_weights::Union{Nothing,Bool}=nothing,
100+
descr::Union{Nothing,String}=nothing,
101+
path::Union{Nothing,String}=nothing,
109102

110103
# preferred names, corresponding to trait names:
111104
input_scitype=input,
112105
target_scitype=target,
113106
output_scitype=output,
114-
supports_weights=weights,
115-
docstring=descr,
116-
load_path=path,
107+
supports_weights::Union{Nothing,Bool}=weights,
108+
supports_class_weights::Union{Nothing,Bool}=class_weights,
109+
docstring::Union{Nothing,String}=descr,
110+
load_path::Union{Nothing,String}=path,
111+
human_name::Union{Nothing,String}=nothing
117112
)
118-
if isempty(load_path)
119-
pname = MLJModelInterface.package_name(T)
120-
mname = MLJModelInterface.name(T)
121-
load_path = "MLJModels.$(pname)_.$(mname)"
122-
end
123-
ex = quote
124-
MLJModelInterface.input_scitype(::Type{<:$T}) = $input_scitype
125-
MLJModelInterface.output_scitype(::Type{<:$T}) = $output_scitype
126-
MLJModelInterface.target_scitype(::Type{<:$T}) = $target_scitype
127-
MLJModelInterface.supports_weights(::Type{<:$T}) = $supports_weights
128-
MLJModelInterface.load_path(::Type{<:$T}) = $load_path
129-
130-
function MLJModelInterface.docstring(::Type{<:$T})
131-
return MLJModelInterface.docstring_ext($T; descr=$docstring)
113+
114+
program = quote end
115+
116+
# Note: Naively using metaprogramming to roll up the following
117+
# code does not work. Only change this if you really know what
118+
# you're doing.
119+
_extend!(program, :input_scitype, input_scitype, T)
120+
_extend!(program, :target_scitype, target_scitype, T)
121+
_extend!(program, :output_scitype, output_scitype, T)
122+
_extend!(program, :supports_weights, supports_weights, T)
123+
_extend!(program, :supports_class_weights,supports_class_weights, T)
124+
_extend!(program, :docstring, docstring, T)
125+
_extend!(program, :load_path, load_path, T)
126+
_extend!(program, :human_name, human_name, T)
127+
128+
parentmodule(T).eval(program)
129+
end
130+
131+
# TODO: After `human_name` trait is added as model trait, include in
132+
# example given in the docstring for `doc_header`.
133+
134+
"""
135+
MLJModelInterface.doc_header(SomeModelType)
136+
137+
Return a string suitable for interpolation in the document string of
138+
an MLJ model type. In the example given below, the header expands to
139+
something like this:
140+
141+
> `FooRegressor`
142+
>
143+
>Model type for foo regressor, based on [FooRegressorPkg.jl](http://existentialcomics.com/).
144+
>
145+
>From MLJ, the type can be imported using
146+
>
147+
>
148+
> `FooRegressor = @load FooRegressor pkg=FooRegressorPkg`
149+
>
150+
>Construct an instance with default hyper-parameters using the syntax
151+
>`model = FooRegressor()`. Provide keyword arguments to override
152+
>hyper-parameter defaults, as in `FooRegressor(a=...)`.
153+
154+
Ordinarily, `doc_header` is used in document strings defined *after*
155+
the model type definition, as `doc_header` assumes model traits (in
156+
particular, `package_name` and `package_url`) to be defined; see also
157+
[`MLJModelInterface.metadata_pkg`](@ref).
158+
159+
160+
### Example
161+
162+
Suppose a model type and traits have been defined by:
163+
164+
```
165+
mutable struct FooRegressor
166+
a::Int
167+
b::Float64
168+
end
169+
170+
metadata_pkg(FooRegressor,
171+
name="FooRegressorPkg",
172+
uuid="10745b16-79ce-11e8-11f9-7d13ad32a3b2",
173+
url="http://existentialcomics.com/",
174+
)
175+
metadata_model(FooRegressor,
176+
input=Table(Continuous),
177+
target=AbstractVector{Continuous},
178+
descr="La di da")
179+
```
180+
181+
Then the docstring is defined post-facto with the following code:
182+
183+
```
184+
const HEADER = MLJModelInterface.doc_header(FooRegressor)
185+
186+
@doc \"\"\"
187+
\$HEADER
188+
189+
### Training data
190+
191+
In MLJ or MLJBase, bind an instance `model` ...
192+
193+
<rest of doc string goes here>
194+
195+
\"\"\" FooRegressor
196+
```
197+
198+
"""
199+
function doc_header(SomeModelType)
200+
name = MLJModelInterface.name(SomeModelType)
201+
human_name = MLJModelInterface.human_name(SomeModelType)
202+
package_name = MLJModelInterface.package_name(SomeModelType)
203+
package_url = MLJModelInterface.package_url(SomeModelType)
204+
params = MLJModelInterface.hyperparameters(SomeModelType)
205+
206+
ret =
207+
"""
208+
```
209+
$name
210+
```
211+
212+
Model type for $human_name, based on
213+
[$(package_name).jl]($package_url), and implementing the MLJ
214+
model interface.
215+
216+
From MLJ, the type can be imported using
217+
218+
```
219+
$name = @load $name pkg=$package_name
220+
```
221+
222+
Do `model = $name()` to construct an instance with default hyper-parameters.
223+
""" |> chomp
224+
225+
ret *= " "
226+
227+
isempty(params) && return ret
228+
229+
p = first(params)
230+
ret *=
231+
"""
232+
Provide keyword arguments to override hyper-parameter defaults, as in
233+
`$name($p=...)`.
234+
""" |> chomp
235+
236+
return ret
237+
end
238+
239+
"""
240+
synthesize_docstring
241+
242+
Private method.
243+
244+
Generates a fallback value for the `docstring` trait, for use with a
245+
model that does not have a standard document string. See
246+
[`metadata_model`](@ref).
247+
248+
"""
249+
function synthesize_docstring(M)
250+
package_name = MLJModelInterface.package_name(M)
251+
package_url = MLJModelInterface.package_url(M)
252+
model_name = MLJModelInterface.name(M)
253+
human_name = MLJModelInterface.human_name(M)
254+
hyperparameters = MLJModelInterface.hyperparameters(M)
255+
256+
# generate text for the section on hyperparameters
257+
text_for_params = ""
258+
if !is_wrapper(M)
259+
model = M()
260+
isempty(hyperparameters) || (text_for_params *= "# Hyper-parameters")
261+
for p in hyperparameters
262+
value = getproperty(model, p)
263+
text_for_params *= "\n\n- `$p = $value`"
132264
end
133265
end
134-
parentmodule(T).eval(ex)
266+
267+
ret = doc_header(M)
268+
if !isempty(text_for_params)
269+
ret *=
270+
"""
271+
272+
$text_for_params
273+
274+
"""
275+
end
276+
return ret
135277
end

src/model_traits.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ const DeterministicDetector = Union{
1313

1414
const StatTraits = StatisticalTraits
1515

16-
StatTraits.docstring(M::Type{<:MLJType}) = name(M)
17-
1816
function StatTraits.docstring(M::Type{<:Model})
19-
return "$(name(M)) from $(package_name(M)).jl.\n" *
20-
"[Documentation]($(package_url(M)))."
17+
docstring = Base.Docs.doc(M) |> string
18+
if occursin("No documentation found", docstring)
19+
docstring = synthesize_docstring(M)
20+
end
21+
return docstring
2122
end
2223

2324
StatTraits.is_supervised(::Type{<:Supervised}) = true

0 commit comments

Comments
 (0)