Skip to content

Commit 36ad9b2

Browse files
committed
ensure predicted multitargets have the right column names
1 parent 43c28ec commit 36ad9b2

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

src/builtins/Constant.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ _mean(y, ::Type{<:Table}) = _mean(Tables.matrix(y), AbstractArray)
3434
_materializer(y) = _materializer(y, scitype(y))
3535
_materializer(y, ::Type{<:AbstractMatrix}) = identity
3636
_materializer(y, ::Type{<:AbstractVector}) = vec
37-
_materializer(y, ::Type{<:Table}) = Tables.materializer(y)Tables.table
37+
function _materializer(y, ::Type{<:Table})
38+
names = Tables.columnnames(Tables.columntable(y))
39+
Tables.materializer(y)(matrix->Tables.table(matrix; header=names))
40+
end
3841

3942
struct DeterministicConstantRegressor <: Deterministic end
4043

@@ -279,9 +282,9 @@ The fields of `fitted_params(mach)` are:
279282
```julia
280283
using MLJ
281284
282-
X, y = make_regression(10, 2; n_targets=3) # synthetic data: two tables
283-
regressor = DeterministicConstantRegressor()
284-
mach = machine(regressor, X, y) |> fit!
285+
X, y = make_regression(10, 2; n_targets=3); # synthetic data: two tables
286+
regressor = DeterministicConstantRegressor();
287+
mach = machine(regressor, X, y) |> fit!;
285288
286289
fitted_params(mach)
287290

test/builtins/Constant.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ end
4848
@test MLJBase.fitted_params(mach).mean [3 6]
4949

5050
# tabular target:
51-
y = Float64[2 5; 3 6; 4 7] |> Tables.table |> Tables.rowtable
51+
y = Tables.table(Float64[2 5; 3 6; 4 7], header=[:x, :y]) |> Tables.rowtable
5252
@test MLJBase.scitype(y) <: S
5353
mach = MLJBase.machine(MLJModels.DeterministicConstantRegressor(), X, y)
5454
MLJBase.fit!(mach, verbosity=0)
5555
yhat = MLJBase.predict(mach, X)
5656
@test yhat isa Vector{<:NamedTuple}
57+
@test keys(yhat[1]) == (:x, :y)
5758
@test Tables.matrix(yhat) [3 6; 3 6; 3 6]
5859
@test MLJBase.fitted_params(mach).mean [3 6]
5960
end

0 commit comments

Comments
 (0)