Skip to content

Commit adb1258

Browse files
authored
Merge pull request #163 from JuliaAI/dev
For a 1.7.0 release
2 parents 3571861 + a346af9 commit adb1258

File tree

7 files changed

+118
-14
lines changed

7 files changed

+118
-14
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.0'
2120
- '1.6'
2221
- '1'
2322
os:

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJModelInterface"
22
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
33
authors = ["Thibaut Lienart and Anthony Blaom"]
4-
version = "1.6.0"
4+
version = "1.7.0"
55

66
[deps]
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -11,17 +11,18 @@ StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"
1111
[compat]
1212
ScientificTypesBase = "3.0"
1313
StatisticalTraits = "3.2"
14-
julia = "1"
14+
julia = "1.6"
1515

1616
[extras]
1717
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
1818
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1919
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
2020
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
2121
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
22+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2223
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
2324
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2425
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2526

2627
[targets]
27-
test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "ScientificTypes", "Tables", "Test"]
28+
test = ["CategoricalArrays", "DataFrames", "Distances", "InteractiveUtils", "Markdown", "OrderedCollections", "ScientificTypes", "Tables", "Test"]

src/metadata_utils.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,16 @@ function _extend!(program::Expr, trait::Symbol, value, T)
6565
end
6666
end
6767

68-
const DEPWARN_DOCSTRING =
69-
"`metadata_model` should not be called with the keyword argument "*
70-
"`descr` or `docstring`. Implementers of the MLJ model interface "*
71-
"should instead create an MLJ-compliant docstring in the usual way. "*
72-
"See https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Document-strings for details. "
68+
depwarn_docstring(T) =
69+
"""
70+
71+
Regarding $T: `metadata_model` should not be called with the keyword argument `descr`
72+
or `docstring`. Implementers of the MLJ model interface should instead create an
73+
MLJ-compliant docstring in the usual way. See
74+
https://alan-turing-institute.github.io/MLJ.jl/dev/adding_models_for_general_use/#Document-strings
75+
for details.
76+
77+
"""
7378

7479
"""
7580
metadata_model(T; args...)
@@ -122,7 +127,7 @@ function metadata_model(
122127
supports_training_losses::Union{Nothing,Bool}=nothing,
123128
reports_feature_importances::Union{Nothing,Bool}=nothing,
124129
)
125-
docstring === nothing || Base.depwarn(DEPWARN_DOCSTRING, :metadata_model)
130+
docstring === nothing || Base.depwarn(depwarn_docstring(T), :metadata_model)
126131

127132
program = quote end
128133

src/model_api.jl

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ part of the tuple returned by `fit`.
100100
101101
"""
102102
fitted_params(::Model, fitresult) = (fitresult=fitresult,)
103-
103+
fitted_params(::Static, ::Nothing) = nothing
104104
"""
105105
106106
predict(model, fitresult, new_data...)
@@ -173,6 +173,8 @@ the feature importances from the model's `fitresult` and `report` as an
173173
abstract vector of `feature::Symbol => importance::Real` pairs
174174
(e.g `[:gender =>0.23, :height =>0.7, :weight => 0.1]`).
175175
176+
# New model implementations
177+
176178
The following trait overload is also required:
177179
`MLJModelInterface.reports_feature_importances(::Type{<:M}) = true`
178180
@@ -182,3 +184,56 @@ If for some reason a model is sometimes unable to report feature importances the
182184
183185
"""
184186
function feature_importances end
187+
188+
_named_tuple(named_tuple::NamedTuple) = named_tuple
189+
_named_tuple(::Nothing) = NamedTuple()
190+
_named_tuple(something_else) = (report=something_else,)
191+
_scrub(x) = x
192+
_scrub(x::NamedTuple) = isempty(x) ? nothing : x
193+
_keys(named_tuple) = keys(named_tuple)
194+
_keys(::Nothing) = ()
195+
196+
"""
197+
MLJModelInterface.report(model, report_given_method)
198+
199+
Merge the reports in the dictionary `report_given_method` into a single
200+
property-accessible object. It is supposed that each key of the dictionary is either
201+
`:fit` or the name of an operation, such as `:predict` or `:transform`. Each value will be
202+
the `report` component returned by a training method (`fit` or `update`) dispatched on the
203+
`model` type, in the case of `:fit`, or the report component returned by an operation that
204+
supports reporting.
205+
206+
# New model implementations
207+
208+
Overloading this method is optional, unless the model generates reports that are neither
209+
named tuples nor `nothing`.
210+
211+
Assuming each value in the `report_given_method` dictionary is either a named tuple
212+
or `nothing`, and there are no conflicts between the keys of the dictionary values
213+
(the individual reports), the fallback returns the usual named tuple merge of the
214+
dictionary values, ignoring any `nothing` value. If there is a key conflict, all operation
215+
reports are first wrapped in a named
216+
tuple of length one, as in `(predict=predict_report,)`. A `:fit` report is never wrapped.
217+
218+
If any dictionary `value` is neither a named tuple nor `nothing`, it is first wrapped as
219+
`(report=value, )` before merging.
220+
221+
"""
222+
function report(model, report_given_method)
223+
224+
return_keys = vcat(collect.(_keys.(values(report_given_method)))...)
225+
226+
# Note that we want to avoid copying values in each individual report named tuple, and
227+
# merge the reports in a reproducible order.
228+
229+
methods = collect(keys(report_given_method)) |> sort!
230+
length(methods) == 1 && return _scrub(report_given_method[only(methods)])
231+
need_to_wrap = return_keys != unique(return_keys)
232+
reports = map(methods) do method
233+
tup = _named_tuple(report_given_method[method])
234+
isempty(tup) ? NamedTuple() :
235+
(need_to_wrap && method !== :fit) ? NamedTuple{(method,)}((tup,)) :
236+
tup
237+
end
238+
return _scrub(merge(reports...))
239+
end

test/data_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,9 @@ end
323323
eval(:(module UserSide
324324
import MLJModelInterface: metadata_model, metadata_pkg
325325
struct A end
326-
descr = "something"
326+
human_name = "Big Foot"
327327
# Smoke tests.
328-
metadata_model(A; descr=descr)
328+
metadata_model(A; human_name)
329329
metadata_pkg(A)
330330
end))
331331
end

test/model_api.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ end
55
f0::Int
66
end
77

8-
98
mutable struct APIx1 <: Static end
109

1110
@testset "selectrows(model, data...)" begin
@@ -95,3 +94,47 @@ mutable struct UnivariateFiniteFitter <: Probabilistic end
9594
@test yhat == fill(DummyUnivariateFinite(), 3)
9695

9796
end
97+
98+
@testset "fallback for `report()` method" begin
99+
report_given_method =
100+
OrderedCollections.OrderedDict(
101+
:predict=>(y=7,),
102+
:fit=>(x=1, z=3),
103+
:transform=>nothing,
104+
)
105+
@test MLJModelInterface.report(APIx0(f0=1), report_given_method) ==
106+
(x=1, z=3, y=7)
107+
108+
report_given_method =
109+
OrderedCollections.OrderedDict(
110+
:predict=>(y=7,),
111+
:fit=>(y=1, z=3),
112+
:transform=>nothing,
113+
)
114+
@test MLJModelInterface.report(APIx0(f0=1), report_given_method) ==
115+
(y=1, z=3, predict=(y=7,))
116+
117+
@test MLJModelInterface.report(
118+
APIx0(f0=1),
119+
OrderedCollections.OrderedDict(:fit => nothing, :transform => NamedTuple()),
120+
) |> isnothing
121+
122+
@test MLJModelInterface.report(
123+
APIx0(f0=1),
124+
OrderedCollections.OrderedDict(:fit => 42),
125+
) == 42
126+
127+
@test MLJModelInterface.report(
128+
APIx0(f0=1),
129+
OrderedCollections.OrderedDict(:fit => nothing),
130+
) |> isnothing
131+
132+
@test MLJModelInterface.report(
133+
APIx0(f0=1),
134+
OrderedCollections.OrderedDict(:fit => NamedTuple()),
135+
) |> isnothing
136+
137+
138+
end
139+
140+

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ScientificTypesBase, ScientificTypes
33
using Tables, Distances, CategoricalArrays, InteractiveUtils
44
import DataFrames: DataFrame
55
import Markdown
6+
import OrderedCollections
67

78
const M = MLJModelInterface
89
const FI = M.FullInterface

0 commit comments

Comments
 (0)