Skip to content

Commit 43c28ec

Browse files
committed
extend DeterministicConstantRegressor to multitargets
1 parent 7384cde commit 43c28ec

File tree

2 files changed

+123
-7
lines changed

2 files changed

+123
-7
lines changed

src/builtins/Constant.jl

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,38 @@ MLJModelInterface.predict(::ConstantRegressor, fitresult, Xnew) =
2424
fill(fitresult, nrows(Xnew))
2525

2626
##
27-
## THE CONSTANT DETERMINISTIC REGRESSOR (FOR TESTING)
27+
## THE CONSTANT DETERMINISTIC REGRESSOR
2828
##
2929

30+
# helpers:
31+
_mean(y) = _mean(y, scitype(y))
32+
_mean(y, ::Type{<:AbstractArray}) = mean(y, dims=1)
33+
_mean(y, ::Type{<:Table}) = _mean(Tables.matrix(y), AbstractArray)
34+
_materializer(y) = _materializer(y, scitype(y))
35+
_materializer(y, ::Type{<:AbstractMatrix}) = identity
36+
_materializer(y, ::Type{<:AbstractVector}) = vec
37+
_materializer(y, ::Type{<:Table}) = Tables.materializer(y)Tables.table
38+
3039
struct DeterministicConstantRegressor <: Deterministic end
3140

3241
function MLJModelInterface.fit(::DeterministicConstantRegressor,
3342
verbosity::Int,
3443
X,
3544
y)
36-
fitresult = mean(y)
45+
μ = _mean(y)
46+
materializer = _materializer(y)
47+
fitresult = (; μ, materializer)
3748
cache = nothing
3849
report = NamedTuple()
3950
return fitresult, cache, report
4051
end
4152

4253
MLJModelInterface.predict(::DeterministicConstantRegressor, fitresult, Xnew) =
43-
fill(fitresult, nrows(Xnew))
54+
hcat([fill(fitresult.μ[i], nrows(Xnew)) for i in eachindex(fitresult.μ)]...) |>
55+
fitresult.materializer
56+
57+
MLJModelInterface.fitted_params(model::DeterministicConstantRegressor, fitresult) =
58+
(; mean=fitresult.μ)
4459

4560
##
4661
## THE CONSTANT CLASSIFIER
@@ -115,7 +130,11 @@ metadata_model(
115130
metadata_model(
116131
DeterministicConstantRegressor,
117132
input_scitype = Table,
118-
target_scitype = AbstractVector{Continuous},
133+
target_scitype = Union{
134+
AbstractMatrix{Continuous},
135+
AbstractVector{Continuous},
136+
Table,
137+
},
119138
supports_weights = false,
120139
load_path = "MLJModels.DeterministicConstantRegressor"
121140
)
@@ -150,6 +169,9 @@ mean or median values instead. If not specified, a normal distribution is fit.
150169
Almost any reasonable model is expected to outperform `ConstantRegressor` which is used
151170
almost exclusively for testing and establishing performance baselines.
152171
172+
If you need a multitarget dummy regressor, consider using `DeterministicConstantRegressor`
173+
instead.
174+
153175
In MLJ (or MLJModels) do `model = ConstantRegressor()` or `model =
154176
ConstantRegressor(distribution=...)` to construct a model instance.
155177
@@ -211,6 +233,67 @@ See also
211233
"""
212234
ConstantRegressor
213235

236+
"""
237+
DeterministicConstantRegressor
238+
239+
This "dummy" predictor always makes the same prediction, irrespective of the provided
240+
input pattern, namely the mean value of the training target values. (It's counterpart,
241+
`ConstantRegressor` makes probabilistic predictions.) This model handles mutlitargets,
242+
i.e, the training target can be a matrix or a table (observations the rows).
243+
244+
Almost any reasonable model is expected to outperform `DeterministicConstantRegressor`
245+
which is used almost exclusively for testing and establishing performance baselines.
246+
247+
In MLJ, do `model = DeterministicConstantRegressor()` to construct a model instance.
248+
249+
250+
# Training data
251+
252+
In MLJ (or MLJBase) bind an instance `model` to data with
253+
254+
mach = machine(model, X, y)
255+
256+
Here:
257+
258+
- `X` is any table of input features (eg, a `DataFrame`)
259+
260+
- `y` is the target, which can be any `AbstractVector`, `AbstractVector` or table whose
261+
element scitype is `Continuous`; check the scitype `scitype(y)` or, for tables, with
262+
`schema(y)`
263+
264+
Train the machine using `fit!(mach, rows=...)`.
265+
266+
# Operations
267+
268+
- `predict(mach, Xnew)`: Return predictions of the target given
269+
features `Xnew` (which for this model are ignored).
270+
271+
# Fitted parameters
272+
273+
The fields of `fitted_params(mach)` are:
274+
275+
- `mean`: The target mean(s). Always a row vector.
276+
277+
# Examples
278+
279+
```julia
280+
using MLJ
281+
282+
X, y = make_regression(10, 2; n_targets=3) # synthetic data: two tables
283+
regressor = DeterministicConstantRegressor()
284+
mach = machine(regressor, X, y) |> fit!
285+
286+
fitted_params(mach)
287+
288+
Xnew, _ = make_regression(3, 2)
289+
predict(mach, Xnew)
290+
291+
```
292+
See also
293+
[`ConstantClassifier`](@ref)
294+
"""
295+
DeterministicConstantRegressor
296+
214297
"""
215298
ConstantClassifier
216299

test/builtins/Constant.jl

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
module TestConstant
22

33
using Test, MLJModels, CategoricalArrays
4-
import Distributions, MLJBase
4+
import Distributions, MLJBase, Tables
5+
56

67
# Any X will do for constant models:
78
X = NamedTuple{(:x1,:x2,:x3)}((rand(10), rand(10), rand(10)))
89

9-
@testset "Regressor" begin
10+
@testset "ConstantRegressor" begin
1011
y = [1.0, 1.0, 2.0, 2.0]
1112

1213
model = ConstantRegressor(distribution_type=
@@ -25,7 +26,39 @@ X = NamedTuple{(:x1,:x2,:x3)}((rand(10), rand(10), rand(10)))
2526
@test MLJBase.load_path(model) == "MLJModels.ConstantRegressor"
2627
end
2728

28-
@testset "Classifier" begin
29+
@testset "DeterministicConstantRegressor" begin
30+
31+
X = (; x=ones(3))
32+
S = MLJBase.target_scitype(DeterministicConstantRegressor())
33+
34+
# vector target:
35+
y = Float64[2, 3, 4]
36+
@test MLJBase.scitype(y) <: S
37+
mach = MLJBase.machine(MLJModels.DeterministicConstantRegressor(), X, y)
38+
MLJBase.fit!(mach, verbosity=0)
39+
@test MLJBase.predict(mach, X) [3, 3, 3]
40+
@test only(MLJBase.fitted_params(mach).mean) 3
41+
42+
# matrix target:
43+
y = Float64[2 5; 3 6; 4 7]
44+
@test MLJBase.scitype(y) <: S
45+
mach = MLJBase.machine(MLJModels.DeterministicConstantRegressor(), X, y)
46+
MLJBase.fit!(mach, verbosity=0)
47+
@test MLJBase.predict(mach, X) [3 6; 3 6; 3 6]
48+
@test MLJBase.fitted_params(mach).mean [3 6]
49+
50+
# tabular target:
51+
y = Float64[2 5; 3 6; 4 7] |> Tables.table |> Tables.rowtable
52+
@test MLJBase.scitype(y) <: S
53+
mach = MLJBase.machine(MLJModels.DeterministicConstantRegressor(), X, y)
54+
MLJBase.fit!(mach, verbosity=0)
55+
yhat = MLJBase.predict(mach, X)
56+
@test yhat isa Vector{<:NamedTuple}
57+
@test Tables.matrix(yhat) [3 6; 3 6; 3 6]
58+
@test MLJBase.fitted_params(mach).mean [3 6]
59+
end
60+
61+
@testset "ConstantClassifier" begin
2962
yraw = ["Perry", "Antonia", "Perry", "Skater"]
3063
y = categorical(yraw)
3164

0 commit comments

Comments
 (0)