
Commit be4ca78

Merge pull request #603 from JuliaAI/dev
For a 0.18.3 release
2 parents 11240ae + 2172e81 commit be4ca78

File tree: 3 files changed (+128 −8 lines)

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "MLJModels"
 uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
 authors = ["Anthony D. Blaom <[email protected]>"]
-version = "0.18.2"
+version = "0.18.3"
 
 [deps]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/builtins/Constant.jl

Lines changed: 90 additions & 4 deletions
@@ -24,23 +24,41 @@ MLJModelInterface.predict(::ConstantRegressor, fitresult, Xnew) =
     fill(fitresult, nrows(Xnew))
 
 ##
-## THE CONSTANT DETERMINISTIC REGRESSOR (FOR TESTING)
+## THE CONSTANT DETERMINISTIC REGRESSOR
 ##
 
+# helpers:
+_mean(y) = _mean(y, scitype(y))
+_mean(y, ::Type{<:AbstractArray}) = mean(y, dims=1)
+_mean(y, ::Type{<:Table}) = _mean(Tables.matrix(y), AbstractArray)
+_materializer(y) = _materializer(y, scitype(y))
+_materializer(y, ::Type{<:AbstractMatrix}) = identity
+_materializer(y, ::Type{<:AbstractVector}) = vec
+function _materializer(y, ::Type{<:Table})
+    names = Tables.columnnames(Tables.columntable(y))
+    Tables.materializer(y) ∘ (matrix -> Tables.table(matrix; header=names))
+end
+
 struct DeterministicConstantRegressor <: Deterministic end
 
 function MLJModelInterface.fit(::DeterministicConstantRegressor,
                                verbosity::Int,
                                X,
                                y)
-    fitresult = mean(y)
+    μ = _mean(y)
+    materializer = _materializer(y)
+    fitresult = (; μ, materializer)
     cache = nothing
     report = NamedTuple()
     return fitresult, cache, report
 end
 
 MLJModelInterface.predict(::DeterministicConstantRegressor, fitresult, Xnew) =
-    fill(fitresult, nrows(Xnew))
+    hcat([fill(fitresult.μ[i], nrows(Xnew)) for i in eachindex(fitresult.μ)]...) |>
+    fitresult.materializer
+
+MLJModelInterface.fitted_params(model::DeterministicConstantRegressor, fitresult) =
+    (; mean=fitresult.μ)
 
 ##
 ## THE CONSTANT CLASSIFIER
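To make the shape handling in the new `predict` method concrete, here is a minimal stand-alone sketch of the column-wise fill (plain Julia; the values of `μ` and `n` are illustrative stand-ins, not taken from the package):

```julia
# Sketch only: reproduces the hcat-of-fills pattern in the predict method above.
μ = [3.0 6.0]   # per-target means as a 1×2 row array, cf. mean(y, dims=1)
n = 3           # stand-in for nrows(Xnew)

yhat = hcat([fill(μ[i], n) for i in eachindex(μ)]...)
# 3×2 Matrix{Float64}: every row equals [3.0 6.0]

# For a single (vector) target the mean has one entry and the materializer is
# `vec`, which flattens the 3×1 result back to an ordinary vector:
vec(hcat([fill(3.0, n)]...))   # == [3.0, 3.0, 3.0]
```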
@@ -115,7 +133,11 @@ metadata_model(
 metadata_model(
     DeterministicConstantRegressor,
     input_scitype = Table,
-    target_scitype = AbstractVector{Continuous},
+    target_scitype = Union{
+        AbstractMatrix{Continuous},
+        AbstractVector{Continuous},
+        Table,
+    },
     supports_weights = false,
     load_path = "MLJModels.DeterministicConstantRegressor"
 )
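The practical effect of widening `target_scitype` is that vector, matrix and tabular targets all pass MLJ's data checks. A quick sketch of how one might verify this, using the same `MLJBase.scitype` and `MLJBase.target_scitype` calls that appear in the updated test file below:

```julia
using MLJModels, MLJBase

S = MLJBase.target_scitype(DeterministicConstantRegressor())

MLJBase.scitype(rand(5)) <: S                  # true: AbstractVector{Continuous}
MLJBase.scitype(rand(5, 2)) <: S               # true: AbstractMatrix{Continuous}
MLJBase.scitype((a=rand(5), b=rand(5))) <: S   # true: a column table
```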
@@ -150,6 +172,9 @@ mean or median values instead. If not specified, a normal distribution is fit.
 Almost any reasonable model is expected to outperform `ConstantRegressor` which is used
 almost exclusively for testing and establishing performance baselines.
 
+If you need a multitarget dummy regressor, consider using `DeterministicConstantRegressor`
+instead.
+
 In MLJ (or MLJModels) do `model = ConstantRegressor()` or `model =
 ConstantRegressor(distribution=...)` to construct a model instance.
@@ -211,6 +236,67 @@ See also
 """
 ConstantRegressor
 
+"""
+    DeterministicConstantRegressor
+
+This "dummy" predictor always makes the same prediction, irrespective of the provided
+input pattern, namely the mean value of the training target values. (Its counterpart,
+`ConstantRegressor`, makes probabilistic predictions.) This model handles multitargets,
+i.e., the training target can be a matrix or a table (with rows as observations).
+
+Almost any reasonable model is expected to outperform `DeterministicConstantRegressor`
+which is used almost exclusively for testing and establishing performance baselines.
+
+In MLJ, do `model = DeterministicConstantRegressor()` to construct a model instance.
+
+
+# Training data
+
+In MLJ (or MLJBase) bind an instance `model` to data with
+
+    mach = machine(model, X, y)
+
+Here:
+
+- `X` is any table of input features (eg, a `DataFrame`)
+
+- `y` is the target, which can be any `AbstractVector`, `AbstractMatrix` or table whose
+  element scitype is `Continuous`; check the scitype with `scitype(y)` or, for tables,
+  with `schema(y)`
+
+Train the machine using `fit!(mach, rows=...)`.
+
+# Operations
+
+- `predict(mach, Xnew)`: Return predictions of the target given
+  features `Xnew` (which for this model are ignored).
+
+# Fitted parameters
+
+The fields of `fitted_params(mach)` are:
+
+- `mean`: The target mean(s), always a row vector (i.e., an `AbstractMatrix` object
+  with row dimension 1).
+
+# Examples
+
+```julia
+using MLJ
+
+X, y = make_regression(10, 2; n_targets=3); # synthetic data: two tables
+regressor = DeterministicConstantRegressor();
+mach = machine(regressor, X, y) |> fit!;
+
+fitted_params(mach)
+
+Xnew, _ = make_regression(3, 2)
+predict(mach, Xnew)
+
+```
+See also
+[`ConstantClassifier`](@ref)
+"""
+DeterministicConstantRegressor
+
 """
     ConstantClassifier
 

test/builtins/Constant.jl

Lines changed: 37 additions & 3 deletions
@@ -1,12 +1,13 @@
 module TestConstant
 
 using Test, MLJModels, CategoricalArrays
-import Distributions, MLJBase
+import Distributions, MLJBase, Tables
+
 
 # Any X will do for constant models:
 X = NamedTuple{(:x1,:x2,:x3)}((rand(10), rand(10), rand(10)))
 
-@testset "Regressor" begin
+@testset "ConstantRegressor" begin
     y = [1.0, 1.0, 2.0, 2.0]
 
     model = ConstantRegressor(distribution_type=
@@ -25,7 +26,40 @@ X = NamedTuple{(:x1,:x2,:x3)}((rand(10), rand(10), rand(10)))
     @test MLJBase.load_path(model) == "MLJModels.ConstantRegressor"
 end
 
-@testset "Classifier" begin
+@testset "DeterministicConstantRegressor" begin
+
+    X = (; x=ones(3))
+    S = MLJBase.target_scitype(DeterministicConstantRegressor())
+
+    # vector target:
+    y = Float64[2, 3, 4]
+    @test MLJBase.scitype(y) <: S
+    mach = MLJBase.machine(MLJModels.DeterministicConstantRegressor(), X, y)
+    MLJBase.fit!(mach, verbosity=0)
+    @test MLJBase.predict(mach, X) ≈ [3, 3, 3]
+    @test only(MLJBase.fitted_params(mach).mean) ≈ 3
+
+    # matrix target:
+    y = Float64[2 5; 3 6; 4 7]
+    @test MLJBase.scitype(y) <: S
+    mach = MLJBase.machine(MLJModels.DeterministicConstantRegressor(), X, y)
+    MLJBase.fit!(mach, verbosity=0)
+    @test MLJBase.predict(mach, X) ≈ [3 6; 3 6; 3 6]
+    @test MLJBase.fitted_params(mach).mean ≈ [3 6]
+
+    # tabular target:
+    y = Tables.table(Float64[2 5; 3 6; 4 7], header=[:x, :y]) |> Tables.rowtable
+    @test MLJBase.scitype(y) <: S
+    mach = MLJBase.machine(MLJModels.DeterministicConstantRegressor(), X, y)
+    MLJBase.fit!(mach, verbosity=0)
+    yhat = MLJBase.predict(mach, X)
+    @test yhat isa Vector{<:NamedTuple}
+    @test keys(yhat[1]) == (:x, :y)
+    @test Tables.matrix(yhat) ≈ [3 6; 3 6; 3 6]
+    @test MLJBase.fitted_params(mach).mean ≈ [3 6]
+end
+
+@testset "ConstantClassifier" begin
     yraw = ["Perry", "Antonia", "Perry", "Skater"]
     y = categorical(yraw)
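As an aside for readers less familiar with the Tables.jl calls in the new testset, here is a minimal sketch of the tabular-target round trip on its own (no MLJ involved; the values mirror the test data):

```julia
using Tables

# Wrap a matrix as a table with named columns, then materialize it as a row
# table (a Vector of NamedTuples), exactly as in the "tabular target" test:
y = Tables.table(Float64[2 5; 3 6; 4 7], header=[:x, :y]) |> Tables.rowtable

y[1]              # (x = 2.0, y = 5.0)
Tables.matrix(y)  # recovers the 3×2 Float64 matrix
```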
