91 changes: 87 additions & 4 deletions src/builtins/Constant.jl
@@ -24,23 +24,38 @@ MLJModelInterface.predict(::ConstantRegressor, fitresult, Xnew) =
fill(fitresult, nrows(Xnew))

##
## THE CONSTANT DETERMINISTIC REGRESSOR (FOR TESTING)
## THE CONSTANT DETERMINISTIC REGRESSOR
##

# helpers:
_mean(y) = _mean(y, scitype(y))
_mean(y, ::Type{<:AbstractArray}) = mean(y, dims=1)
_mean(y, ::Type{<:Table}) = _mean(Tables.matrix(y), AbstractArray)
_materializer(y) = _materializer(y, scitype(y))
_materializer(y, ::Type{<:AbstractMatrix}) = identity
_materializer(y, ::Type{<:AbstractVector}) = vec
_materializer(y, ::Type{<:Table}) = Tables.materializer(y)∘Tables.table
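# For orientation only, a rough sketch of how the helper dispatch above resolves for
# each supported target type (illustrative, not part of the change; assumes `mean`
# from Statistics and `Tables` are in scope as elsewhere in this file):
#
#     y = [2.0, 3.0, 4.0]                            # vector target
#     _mean(y)           # ≈ [3.0]                   (mean(y, dims=1))
#     _materializer(y)   # vec                       (flattens the single hcat'ed column)
#
#     y = [2.0 5.0; 3.0 6.0; 4.0 7.0]                # matrix target: 3 observations, 2 targets
#     _mean(y)           # ≈ [3.0 6.0]               (1×2 row matrix)
#     _materializer(y)   # identity
#
#     y = Tables.table([2.0 5.0; 3.0 6.0; 4.0 7.0])  # tabular target
#     _mean(y)           # ≈ [3.0 6.0]
#     _materializer(y)   # Tables.materializer(y) ∘ Tables.table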

struct DeterministicConstantRegressor <: Deterministic end

function MLJModelInterface.fit(::DeterministicConstantRegressor,
verbosity::Int,
X,
y)
fitresult = mean(y)
μ = _mean(y)
materializer = _materializer(y)
fitresult = (; μ, materializer)
cache = nothing
report = NamedTuple()
return fitresult, cache, report
end

MLJModelInterface.predict(::DeterministicConstantRegressor, fitresult, Xnew) =
fill(fitresult, nrows(Xnew))
hcat([fill(fitresult.μ[i], nrows(Xnew)) for i in eachindex(fitresult.μ)]...) |>
fitresult.materializer

MLJModelInterface.fitted_params(model::DeterministicConstantRegressor, fitresult) =
(; mean=fitresult.μ)
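# For orientation only, an assumed round trip through the methods above for a
# two-target matrix target (`X` is ignored by `fit`, so `nothing` suffices here):
#
#     model = DeterministicConstantRegressor()
#     y = Float64[2 5; 3 6; 4 7]
#     fitresult, _, _ = MLJModelInterface.fit(model, 0, nothing, y)
#     MLJModelInterface.predict(model, fitresult, (; x = ones(2)))
#     # expected: [3.0 6.0; 3.0 6.0]
#     MLJModelInterface.fitted_params(model, fitresult)
#     # expected: (mean = [3.0 6.0],)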

##
## THE CONSTANT CLASSIFIER
@@ -115,7 +130,11 @@ metadata_model(
metadata_model(
DeterministicConstantRegressor,
input_scitype = Table,
target_scitype = AbstractVector{Continuous},
target_scitype = Union{
AbstractMatrix{Continuous},
AbstractVector{Continuous},
Table,
},
supports_weights = false,
load_path = "MLJModels.DeterministicConstantRegressor"
)
@@ -150,6 +169,9 @@ mean or median values instead. If not specified, a normal distribution is fit.
Almost any reasonable model is expected to outperform `ConstantRegressor` which is used
almost exclusively for testing and establishing performance baselines.

If you need a multitarget dummy regressor, consider using `DeterministicConstantRegressor`
instead.

In MLJ (or MLJModels) do `model = ConstantRegressor()` or `model =
ConstantRegressor(distribution_type=...)` to construct a model instance.

@@ -211,6 +233,67 @@ See also
"""
ConstantRegressor

"""
DeterministicConstantRegressor

This "dummy" predictor always makes the same prediction, irrespective of the provided
input pattern, namely the mean value of the training target values. (It's counterpart,
`ConstantRegressor` makes probabilistic predictions.) This model handles mutlitargets,
i.e, the training target can be a matrix or a table (observations the rows).
Member: seems like there is a typo here.

Member Author: Sorry, where?


Almost any reasonable model is expected to outperform `DeterministicConstantRegressor`
which is used almost exclusively for testing and establishing performance baselines.

In MLJ, do `model = DeterministicConstantRegressor()` to construct a model instance.


# Training data

In MLJ (or MLJBase) bind an instance `model` to data with

mach = machine(model, X, y)

Here:

- `X` is any table of input features (eg, a `DataFrame`)

- `y` is the target, which can be any `AbstractVector`, `AbstractMatrix` or table whose
element scitype is `Continuous`; check the scitype with `scitype(y)` or, for tables, with
`schema(y)`

Train the machine using `fit!(mach, rows=...)`.

# Operations

- `predict(mach, Xnew)`: Return predictions of the target given
features `Xnew` (which for this model are ignored).

# Fitted parameters

The fields of `fitted_params(mach)` are:

- `mean`: The target mean(s); a one-element vector in the single-target case, and a
one-row matrix (one column per target) otherwise.

# Examples

```julia
using MLJ

X, y = make_regression(10, 2; n_targets=3) # synthetic data: two tables
regressor = DeterministicConstantRegressor()
mach = machine(regressor, X, y) |> fit!

fitted_params(mach)

Xnew, _ = make_regression(3, 2)
predict(mach, Xnew)

```
See also
[`ConstantClassifier`](@ref)
"""
DeterministicConstantRegressor

"""
ConstantClassifier

39 changes: 36 additions & 3 deletions test/builtins/Constant.jl
@@ -1,12 +1,13 @@
module TestConstant

using Test, MLJModels, CategoricalArrays
import Distributions, MLJBase
import Distributions, MLJBase, Tables


# Any X will do for constant models:
X = NamedTuple{(:x1,:x2,:x3)}((rand(10), rand(10), rand(10)))

@testset "Regressor" begin
@testset "ConstantRegressor" begin
y = [1.0, 1.0, 2.0, 2.0]

model = ConstantRegressor(distribution_type=
@@ -25,7 +26,39 @@ X = NamedTuple{(:x1,:x2,:x3)}((rand(10), rand(10), rand(10)))
@test MLJBase.load_path(model) == "MLJModels.ConstantRegressor"
end

@testset "Classifier" begin
@testset "DeterministicConstantRegressor" begin

X = (; x=ones(3))
S = MLJBase.target_scitype(DeterministicConstantRegressor())

# vector target:
y = Float64[2, 3, 4]
@test MLJBase.scitype(y) <: S
mach = MLJBase.machine(MLJModels.DeterministicConstantRegressor(), X, y)
MLJBase.fit!(mach, verbosity=0)
@test MLJBase.predict(mach, X) ≈ [3, 3, 3]
@test only(MLJBase.fitted_params(mach).mean) ≈ 3

# matrix target:
y = Float64[2 5; 3 6; 4 7]
@test MLJBase.scitype(y) <: S
mach = MLJBase.machine(MLJModels.DeterministicConstantRegressor(), X, y)
MLJBase.fit!(mach, verbosity=0)
@test MLJBase.predict(mach, X) ≈ [3 6; 3 6; 3 6]
@test MLJBase.fitted_params(mach).mean ≈ [3 6]

# tabular target:
y = Float64[2 5; 3 6; 4 7] |> Tables.table |> Tables.rowtable
@test MLJBase.scitype(y) <: S
mach = MLJBase.machine(MLJModels.DeterministicConstantRegressor(), X, y)
MLJBase.fit!(mach, verbosity=0)
yhat = MLJBase.predict(mach, X)
@test yhat isa Vector{<:NamedTuple}
@test Tables.matrix(yhat) ≈ [3 6; 3 6; 3 6]
@test MLJBase.fitted_params(mach).mean ≈ [3 6]
end

@testset "ConstantClassifier" begin
yraw = ["Perry", "Antonia", "Perry", "Skater"]
y = categorical(yraw)
