@@ -24,23 +24,38 @@ MLJModelInterface.predict(::ConstantRegressor, fitresult, Xnew) =
2424 fill (fitresult, nrows (Xnew))
2525
2626# #
27- # # THE CONSTANT DETERMINISTIC REGRESSOR (FOR TESTING)
27+ # # THE CONSTANT DETERMINISTIC REGRESSOR
2828# #
2929
30+ # helpers:
31+ _mean (y) = _mean (y, scitype (y))
32+ _mean (y, :: Type{<:AbstractArray} ) = mean (y, dims= 1 )
33+ _mean (y, :: Type{<:Table} ) = _mean (Tables. matrix (y), AbstractArray)
34+ _materializer (y) = _materializer (y, scitype (y))
35+ _materializer (y, :: Type{<:AbstractMatrix} ) = identity
36+ _materializer (y, :: Type{<:AbstractVector} ) = vec
37+ _materializer (y, :: Type{<:Table} ) = Tables. materializer (y)∘ Tables. table
38+
3039struct DeterministicConstantRegressor <: Deterministic end
3140
3241function MLJModelInterface. fit (:: DeterministicConstantRegressor ,
3342 verbosity:: Int ,
3443 X,
3544 y)
36- fitresult = mean (y)
45+ μ = _mean (y)
46+ materializer = _materializer (y)
47+ fitresult = (; μ, materializer)
3748 cache = nothing
3849 report = NamedTuple ()
3950 return fitresult, cache, report
4051end
4152
4253MLJModelInterface. predict (:: DeterministicConstantRegressor , fitresult, Xnew) =
43- fill (fitresult, nrows (Xnew))
54+ hcat ([fill (fitresult. μ[i], nrows (Xnew)) for i in eachindex (fitresult. μ)]. .. ) |>
55+ fitresult. materializer
56+
57+ MLJModelInterface. fitted_params (model:: DeterministicConstantRegressor , fitresult) =
58+ (; mean= fitresult. μ)
4459
4560# #
4661# # THE CONSTANT CLASSIFIER
@@ -115,7 +130,11 @@ metadata_model(
115130metadata_model (
116131 DeterministicConstantRegressor,
117132 input_scitype = Table,
118- target_scitype = AbstractVector{Continuous},
133+ target_scitype = Union{
134+ AbstractMatrix{Continuous},
135+ AbstractVector{Continuous},
136+ Table,
137+ },
119138 supports_weights = false ,
120139 load_path = " MLJModels.DeterministicConstantRegressor"
121140)
@@ -150,6 +169,9 @@ mean or median values instead. If not specified, a normal distribution is fit.
150169Almost any reasonable model is expected to outperform `ConstantRegressor` which is used
151170almost exclusively for testing and establishing performance baselines.
152171
172+ If you need a multitarget dummy regressor, consider using `DeterministicConstantRegressor`
173+ instead.
174+
153175In MLJ (or MLJModels) do `model = ConstantRegressor()` or `model =
154176ConstantRegressor(distribution=...)` to construct a model instance.
155177
@@ -211,6 +233,67 @@ See also
211233"""
212234ConstantRegressor
213235
236+ """
237+ DeterministicConstantRegressor
238+
239+ This "dummy" predictor always makes the same prediction, irrespective of the provided
240+ input pattern, namely the mean value of the training target values. (It's counterpart,
241+ `ConstantRegressor` makes probabilistic predictions.) This model handles mutlitargets,
242+ i.e, the training target can be a matrix or a table (observations the rows).
243+
244+ Almost any reasonable model is expected to outperform `DeterministicConstantRegressor`
245+ which is used almost exclusively for testing and establishing performance baselines.
246+
247+ In MLJ, do `model = DeterministicConstantRegressor()` to construct a model instance.
248+
249+
250+ # Training data
251+
252+ In MLJ (or MLJBase) bind an instance `model` to data with
253+
254+ mach = machine(model, X, y)
255+
256+ Here:
257+
258+ - `X` is any table of input features (eg, a `DataFrame`)
259+
260+ - `y` is the target, which can be any `AbstractVector`, `AbstractVector` or table whose
261+ element scitype is `Continuous`; check the scitype `scitype(y)` or, for tables, with
262+ `schema(y)`
263+
264+ Train the machine using `fit!(mach, rows=...)`.
265+
266+ # Operations
267+
268+ - `predict(mach, Xnew)`: Return predictions of the target given
269+ features `Xnew` (which for this model are ignored).
270+
271+ # Fitted parameters
272+
273+ The fields of `fitted_params(mach)` are:
274+
275+ - `mean`: The target mean(s). Always a row vector.
276+
277+ # Examples
278+
279+ ```julia
280+ using MLJ
281+
282+ X, y = make_regression(10, 2; n_targets=3) # synthetic data: two tables
283+ regressor = DeterministicConstantRegressor()
284+ mach = machine(regressor, X, y) |> fit!
285+
286+ fitted_params(mach)
287+
288+ Xnew, _ = make_regression(3, 2)
289+ predict(mach, Xnew)
290+
291+ ```
292+ See also
293+ [`ConstantClassifier`](@ref)
294+ """
295+ DeterministicConstantRegressor
296+
214297"""
215298 ConstantClassifier
216299
0 commit comments