Skip to content

Commit 507684f

Browse files
authored
Add StatsLearnModel (#9)
* Add 'StatsLearnModel' * Update 'Learn'
1 parent 88c84c8 commit 507684f

File tree

4 files changed

+34
-5
lines changed

4 files changed

+34
-5
lines changed

src/StatsLearnModels.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using Tables
88
using Distances
99
using DataScienceTraits
1010
using StatsBase: mode, mean
11-
using ColumnSelectors: selector
11+
using ColumnSelectors: ColumnSelector, selector
1212
using TableTransforms: StatelessFeatureTransform
1313

1414
import DataScienceTraits as DST

src/interface.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,22 @@ struct FittedModel{M,C}
3131
end
3232

3333
Base.show(io::IO, ::FittedModel{M}) where {M} = print(io, "FittedModel{$(nameof(M))}")
34+
35+
"""
36+
StatsLearnModels.StatsLearnModel(model, incols, outcols)
37+
38+
Wrapper type for learning models used for dispatch purposes.
39+
"""
40+
struct StatsLearnModel{M,I<:ColumnSelector,O<:ColumnSelector}
41+
model::M
42+
input::I
43+
output::O
44+
end
45+
46+
StatsLearnModel(model, incols, outcols) = StatsLearnModel(model, selector(incols), selector(outcols))
47+
48+
function Base.show(io::IO, model::StatsLearnModel{M}) where {M}
49+
println(io, "StatsLearnModel{$(nameof(M))}")
50+
println(io, "├─ input: $(model.input)")
51+
print(io, "└─ output: $(model.output)")
52+
end

src/learn.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,22 @@ struct Learn{M<:FittedModel} <: StatelessFeatureTransform
2626
input::Vector{Symbol}
2727
end
2828

29-
function Learn(train, model, (incols, outcols)::Pair)
29+
Learn(train, model, (incols, outcols)::Pair) = Learn(train, StatsLearnModel(model, incols, outcols))
30+
31+
function Learn(train, lmodel::StatsLearnModel)
3032
if !Tables.istable(train)
3133
throw(ArgumentError("training data must be a table"))
3234
end
3335

3436
cols = Tables.columns(train)
3537
names = Tables.columnnames(cols)
36-
innms = selector(incols)(names)
37-
outnms = selector(outcols)(names)
38+
innms = lmodel.input(names)
39+
outnms = lmodel.output(names)
3840

3941
input = (; (nm => Tables.getcolumn(cols, nm) for nm in innms)...)
4042
output = (; (nm => Tables.getcolumn(cols, nm) for nm in outnms)...)
4143

42-
fmodel = fit(model, input, output)
44+
fmodel = fit(lmodel.model, input, output)
4345
Learn(fmodel, innms)
4446
end
4547

test/runtests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ const SLM = StatsLearnModels
2323
@test sprint(show, fmodel) == "FittedModel{DecisionTreeClassifier}"
2424
end
2525

26+
@testset "StatsLearnModel" begin
27+
model = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
28+
@test sprint(show, model) == """
29+
StatsLearnModel{DecisionTreeClassifier}
30+
├─ input: [:a, :b]
31+
└─ output: :c"""
32+
end
33+
2634
@testset "models" begin
2735
@testset "MLJ" begin
2836
Random.seed!(123)

0 commit comments

Comments
 (0)