@@ -100,7 +100,7 @@ part of the tuple returned by `fit`.
100
100
101
101
"""
102
102
fitted_params (:: Model , fitresult) = (fitresult= fitresult,)
103
-
103
+ fitted_params ( :: Static , :: Nothing ) = nothing
104
104
"""
105
105
106
106
predict(model, fitresult, new_data...)
@@ -173,6 +173,8 @@ the feature importances from the model's `fitresult` and `report` as an
173
173
abstract vector of `feature::Symbol => importance::Real` pairs
174
174
(e.g `[:gender =>0.23, :height =>0.7, :weight => 0.1]`).
175
175
176
+ # New model implementations
177
+
176
178
The following trait overload is also required:
177
179
`MLJModelInterface.reports_feature_importances(::Type{<:M}) = true`
178
180
@@ -182,3 +184,56 @@ If for some reason a model is sometimes unable to report feature importances the
182
184
183
185
"""
184
186
function feature_importances end
187
+
188
+ _named_tuple (named_tuple:: NamedTuple ) = named_tuple
189
+ _named_tuple (:: Nothing ) = NamedTuple ()
190
+ _named_tuple (something_else) = (report= something_else,)
191
+ _scrub (x) = x
192
+ _scrub (x:: NamedTuple ) = isempty (x) ? nothing : x
193
+ _keys (named_tuple) = keys (named_tuple)
194
+ _keys (:: Nothing ) = ()
195
+
196
+ """
197
+ MLJModelInterface.report(model, report_given_method)
198
+
199
+ Merge the reports in the dictionary `report_given_method` into a single
200
+ property-accessible object. It is supposed that each key of the dictionary is either
201
+ `:fit` or the name of an operation, such as `:predict` or `:transform`. Each value will be
202
+ the `report` component returned by a training method (`fit` or `update`) dispatched on the
203
+ `model` type, in the case of `:fit`, or the report component returned by an operation that
204
+ supports reporting.
205
+
206
+ # New model implementations
207
+
208
+ Overloading this method is optional, unless the model generates reports that are neither
209
+ named tuples nor `nothing`.
210
+
211
+ Assuming each value in the `report_given_method` dictionary is either a named tuple
212
+ or `nothing`, and there are no conflicts between the keys of the dictionary values
213
+ (the individual reports), the fallback returns the usual named tuple merge of the
214
+ dictionary values, ignoring any `nothing` value. If there is a key conflict, all operation
215
+ reports are first wrapped in a named
216
+ tuple of length one, as in `(predict=predict_report,)`. A `:fit` report is never wrapped.
217
+
218
+ If any dictionary `value` is neither a named tuple nor `nothing`, it is first wrapped as
219
+ `(report=value, )` before merging.
220
+
221
+ """
222
+ function report (model, report_given_method)
223
+
224
+ return_keys = vcat (collect .(_keys .(values (report_given_method)))... )
225
+
226
+ # Note that we want to avoid copying values in each individual report named tuple, and
227
+ # merge the reports in a reproducible order.
228
+
229
+ methods = collect (keys (report_given_method)) |> sort!
230
+ length (methods) == 1 && return _scrub (report_given_method[only (methods)])
231
+ need_to_wrap = return_keys != unique (return_keys)
232
+ reports = map (methods) do method
233
+ tup = _named_tuple (report_given_method[method])
234
+ isempty (tup) ? NamedTuple () :
235
+ (need_to_wrap && method != = :fit ) ? NamedTuple {(method,)} ((tup,)) :
236
+ tup
237
+ end
238
+ return _scrub (merge (reports... ))
239
+ end
0 commit comments