@@ -24,23 +24,41 @@ MLJModelInterface.predict(::ConstantRegressor, fitresult, Xnew) =
2424 fill (fitresult, nrows (Xnew))
2525
2626# #
27- # # THE CONSTANT DETERMINISTIC REGRESSOR (FOR TESTING)
27+ # # THE CONSTANT DETERMINISTIC REGRESSOR
2828# #
2929
30+ # helpers:
31+ _mean (y) = _mean (y, scitype (y))
32+ _mean (y, :: Type{<:AbstractArray} ) = mean (y, dims= 1 )
33+ _mean (y, :: Type{<:Table} ) = _mean (Tables. matrix (y), AbstractArray)
34+ _materializer (y) = _materializer (y, scitype (y))
35+ _materializer (y, :: Type{<:AbstractMatrix} ) = identity
36+ _materializer (y, :: Type{<:AbstractVector} ) = vec
37+ function _materializer (y, :: Type{<:Table} )
38+ names = Tables. columnnames (Tables. columntable (y))
39+ Tables. materializer (y)∘ (matrix-> Tables. table (matrix; header= names))
40+ end
41+
3042struct DeterministicConstantRegressor <: Deterministic end
3143
3244function MLJModelInterface. fit (:: DeterministicConstantRegressor ,
3345 verbosity:: Int ,
3446 X,
3547 y)
36- fitresult = mean (y)
48+ μ = _mean (y)
49+ materializer = _materializer (y)
50+ fitresult = (; μ, materializer)
3751 cache = nothing
3852 report = NamedTuple ()
3953 return fitresult, cache, report
4054end
4155
4256MLJModelInterface. predict (:: DeterministicConstantRegressor , fitresult, Xnew) =
43- fill (fitresult, nrows (Xnew))
57+ hcat ([fill (fitresult. μ[i], nrows (Xnew)) for i in eachindex (fitresult. μ)]. .. ) |>
58+ fitresult. materializer
59+
60+ MLJModelInterface. fitted_params (model:: DeterministicConstantRegressor , fitresult) =
61+ (; mean= fitresult. μ)
4462
4563# #
4664# # THE CONSTANT CLASSIFIER
@@ -115,7 +133,11 @@ metadata_model(
115133metadata_model (
116134 DeterministicConstantRegressor,
117135 input_scitype = Table,
118- target_scitype = AbstractVector{Continuous},
136+ target_scitype = Union{
137+ AbstractMatrix{Continuous},
138+ AbstractVector{Continuous},
139+ Table,
140+ },
119141 supports_weights = false ,
120142 load_path = " MLJModels.DeterministicConstantRegressor"
121143)
@@ -150,6 +172,9 @@ mean or median values instead. If not specified, a normal distribution is fit.
150172Almost any reasonable model is expected to outperform `ConstantRegressor` which is used
151173almost exclusively for testing and establishing performance baselines.
152174
175+ If you need a multitarget dummy regressor, consider using `DeterministicConstantRegressor`
176+ instead.
177+
153178In MLJ (or MLJModels) do `model = ConstantRegressor()` or `model =
154179ConstantRegressor(distribution=...)` to construct a model instance.
155180
@@ -211,6 +236,67 @@ See also
211236"""
212237ConstantRegressor
213238
239+ """
240+ DeterministicConstantRegressor
241+
242+ This "dummy" predictor always makes the same prediction, irrespective of the provided
243+ input pattern, namely the mean value of the training target values. (It's counterpart,
244+ `ConstantRegressor` makes probabilistic predictions.) This model handles mutlitargets,
245+ i.e, the training target can be a matrix or a table (with rows as observations).
246+
247+ Almost any reasonable model is expected to outperform `DeterministicConstantRegressor`
248+ which is used almost exclusively for testing and establishing performance baselines.
249+
250+ In MLJ, do `model = DeterministicConstantRegressor()` to construct a model instance.
251+
252+
253+ # Training data
254+
255+ In MLJ (or MLJBase) bind an instance `model` to data with
256+
257+ mach = machine(model, X, y)
258+
259+ Here:
260+
261+ - `X` is any table of input features (eg, a `DataFrame`)
262+
263+ - `y` is the target, which can be any `AbstractVector`, `AbstractVector` or table whose
264+ element scitype is `Continuous`; check the scitype `scitype(y)` or, for tables, with
265+ `schema(y)`
266+
267+ Train the machine using `fit!(mach, rows=...)`.
268+
269+ # Operations
270+
271+ - `predict(mach, Xnew)`: Return predictions of the target given
272+ features `Xnew` (which for this model are ignored).
273+
274+ # Fitted parameters
275+
276+ The fields of `fitted_params(mach)` are:
277+
278+ - `mean`: The target mean(s). Always a row vector. (i.e an `AbstractMatrix` object with row dim 1)
279+
280+ # Examples
281+
282+ ```julia
283+ using MLJ
284+
285+ X, y = make_regression(10, 2; n_targets=3); # synthetic data: two tables
286+ regressor = DeterministicConstantRegressor();
287+ mach = machine(regressor, X, y) |> fit!;
288+
289+ fitted_params(mach)
290+
291+ Xnew, _ = make_regression(3, 2)
292+ predict(mach, Xnew)
293+
294+ ```
295+ See also
296+ [`ConstantClassifier`](@ref)
297+ """
298+ DeterministicConstantRegressor
299+
214300"""
215301 ConstantClassifier
216302
0 commit comments