Skip to content

Commit dcf33d6

Browse files
authored
Merge pull request #169 from JuliaAI/dev
For a 1.8 release
2 parents 13933a8 + d9e9703 commit dcf33d6

File tree

5 files changed

+92
-7
lines changed

5 files changed

+92
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "1.7.1"
4+
version = "1.8.0"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

src/MLJModelInterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ const ABSTRACT_MODEL_SUBTYPES = [
4141
:Probabilistic,
4242
:Deterministic,
4343
:Interval,
44+
:ProbabilisticSet,
4445
:JointProbabilistic,
4546
:Static,
4647
:Annotator,
@@ -143,6 +144,7 @@ abstract type Annotator <: Model end
143144
abstract type Probabilistic <: Supervised end
144145
abstract type Deterministic <: Supervised end
145146
abstract type Interval <: Supervised end
147+
abstract type ProbabilisticSet <: Supervised end
146148

147149
abstract type JointProbabilistic <: Probabilistic end
148150

src/metadata_utils.jl

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ end
152152
# example given in the docstring for `doc_header`.
153153

154154
"""
155-
MLJModelInterface.doc_header(SomeModelType)
155+
MLJModelInterface.doc_header(SomeModelType; augment=false)
156156
157157
Return a string suitable for interpolation in the document string of
158158
an MLJ model type. In the example given below, the header expands to
@@ -199,7 +199,7 @@ metadata_model(FooRegressor,
199199
descr="La di da")
200200
```
201201
202-
Then the docstring is defined post-facto with the following code:
202+
Then the docstring is defined after these declarations with the following code:
203203
204204
```
205205
\"\"\"
@@ -216,15 +216,34 @@ FooRegressor
216216
217217
```
218218
219+
# Variation to augment existing document string
220+
221+
For models that have a native API with separate documentation, one may want to call
222+
`doc_header(FooRegressor, augment=true)` instead. In that case, the output will look like
223+
this:
224+
225+
>From MLJ, the `FooRegressor` type can be imported using
226+
>
227+
>
228+
> `FooRegressor = @load FooRegressor pkg=FooRegressorPkg`
229+
>
230+
>Construct an instance with default hyper-parameters using the syntax
231+
>`model = FooRegressor()`. Provide keyword arguments to override
232+
>hyper-parameter defaults, as in `FooRegressor(a=...)`.
233+
219234
"""
220-
function doc_header(SomeModelType)
235+
function doc_header(SomeModelType; augment=false)
221236
name = MLJModelInterface.name(SomeModelType)
222237
human_name = MLJModelInterface.human_name(SomeModelType)
223238
package_name = MLJModelInterface.package_name(SomeModelType)
224239
package_url = MLJModelInterface.package_url(SomeModelType)
225240
params = MLJModelInterface.hyperparameters(SomeModelType)
226241

227-
ret =
242+
top = augment ?
243+
"""
244+
From MLJ, the `$name` type can be imported using
245+
246+
""" :
228247
"""
229248
```
230249
$name
@@ -235,7 +254,9 @@ function doc_header(SomeModelType)
235254
model interface.
236255
237256
From MLJ, the type can be imported using
238-
257+
"""
258+
ret = top*
259+
"""
239260
```
240261
$name = @load $name pkg=$package_name
241262
```

src/model_traits.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ StatTraits.is_supervised(::Type{<:SupervisedAnnotator}) = true
2727
StatTraits.prediction_type(::Type{<:Deterministic}) = :deterministic
2828
StatTraits.prediction_type(::Type{<:Probabilistic}) = :probabilistic
2929
StatTraits.prediction_type(::Type{<:Interval}) = :interval
30+
StatTraits.prediction_type(::Type{<:ProbabilisticSet}) = :probabilistic_set
3031
StatTraits.prediction_type(::Type{<:ProbabilisticDetector}) = :probabilistic
3132
StatTraits.prediction_type(::Type{<:DeterministicDetector}) = :deterministic
3233

test/metadata_utils.jl

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ FooRegressor
158158
```
159159
160160
A model type for constructing a foo regressor, based on
161-
[FooRegressorPkg.jl](http://existentialcomics.com/).
161+
[FooRegressorPkg.jl](http://existentialcomics.com/), and implementing the MLJ model
162+
interface.
162163
163164
From MLJ, the type can be imported using
164165
@@ -171,9 +172,69 @@ Provide keyword arguments to override hyper-parameter
171172
defaults, as in `FooRegressor(a=...)`.
172173
""" |> chomp |> Markdown.parse
173174

175+
@test string(header) == string(comparison)
174176
end
175177

176178
@testset "document string" begin
177179
doc = (@doc FooRegressor) |> string |> chomp
178180
@test endswith(doc, "We have no bananas today!")
179181
end
182+
183+
184+
# # DOC STRING - AUGMENTED CASE
185+
186+
"""Cool model"""
187+
@mlj_model mutable struct FooRegressor2 <: Deterministic
188+
a::Int = 0::(_ ≥ 0)
189+
b
190+
end
191+
192+
metadata_pkg(FooRegressor2,
193+
name="FooRegressor2Pkg",
194+
uuid="10745b16-79ce-11e8-11f9-7d13ad32a3b2",
195+
url="http://existentialcomics.com/",
196+
julia=true,
197+
license="MIT",
198+
is_wrapper=false
199+
)
200+
201+
# this is added in MLJBase but not in MLJModelInterface, to avoid
202+
# InteractiveUtils as dependency:
203+
setfull()
204+
M.implemented_methods(::FI, M::Type{<:MLJType}) =
205+
getfield.(methodswith(M), :name)
206+
207+
const HEADER2 = MLJModelInterface.doc_header(FooRegressor2, augment=true)
208+
209+
@doc """
210+
$HEADER2
211+
212+
Yes, we have no bananas. We have no bananas today!
213+
""" FooRegressor2
214+
215+
@testset "doc_header(ModelType, augment=true)" begin
216+
217+
# we test markdown parsed strings for less fussy comparison
218+
219+
header = Markdown.parse(HEADER2)
220+
comparison =
221+
"""
222+
From MLJ, the `FooRegressor2` type can be imported using
223+
224+
```
225+
FooRegressor2 = @load FooRegressor2 pkg=FooRegressor2Pkg
226+
```
227+
228+
Do `model = FooRegressor2()` to construct an instance with default hyper-parameters.
229+
Provide keyword arguments to override hyper-parameter
230+
defaults, as in `FooRegressor2(a=...)`.
231+
""" |> chomp |> Markdown.parse
232+
233+
@test string(header) == string(comparison)
234+
235+
end
236+
237+
@testset "document string" begin
238+
doc = (@doc FooRegressor2) |> string |> chomp
239+
@test endswith(doc, "We have no bananas today!")
240+
end

0 commit comments

Comments
 (0)