Skip to content

Commit 1c0957b

Browse files
committed
More refactoring
1 parent e9e18a3 commit 1c0957b

File tree

3 files changed

+27
-28
lines changed

3 files changed

+27
-28
lines changed

src/interface.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
# ------------------------------------------------------------------
44

55
"""
6-
StatsLearnModel(model, invars, outvars)
6+
StatsLearnModel(model, features, targets)
77
8-
Wrap a (possibly external) `model` with selectors of
9-
input variables `invars` and output variables `outvars`.
8+
Wrap a (possibly external) `model` with selectors
9+
of `features` and `targets`.
1010
1111
## Examples
1212
@@ -17,11 +17,11 @@ StatsLearnModel(DecisionTreeClassifier(), 1:3, "target")
1717
"""
1818
struct StatsLearnModel{M,I<:ColumnSelector,O<:ColumnSelector}
1919
model::M
20-
invars::I
21-
outvars::O
20+
feats::I
21+
targs::O
2222
end
2323

24-
StatsLearnModel(model, invars, outvars) = StatsLearnModel(model, selector(invars), selector(outvars))
24+
StatsLearnModel(model, feats, targs) = StatsLearnModel(model, selector(feats), selector(targs))
2525

2626
"""
2727
fit(model, input, output)
@@ -34,8 +34,8 @@ function fit end
3434

3535
function Base.show(io::IO, model::StatsLearnModel{M}) where {M}
3636
println(io, "StatsLearnModel{$(nameof(M))}")
37-
println(io, "├─ features: $(model.invars)")
38-
print(io, "└─ targets: $(model.outvars)")
37+
println(io, "├─ features: $(model.feats)")
38+
print(io, "└─ targets: $(model.targs)")
3939
end
4040

4141
"""

src/learn.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
# ------------------------------------------------------------------
44

55
"""
6-
Learn(train, model, invars => outvars)
6+
Learn(train, model, features => targets)
77
88
Fits the statistical learning `model` to `train` table,
9-
using the selectors of input variables `invars` and
10-
output variables `outvars`.
9+
using the selectors of `features` and `targets`.
1110
1211
# Examples
1312
@@ -21,10 +20,10 @@ Learn(train, model, r"[abc]" => ["d", "e"])
2120
"""
2221
struct Learn{M<:FittedStatsLearnModel} <: StatelessFeatureTransform
2322
model::M
24-
invars::Vector{Symbol}
23+
feats::Vector{Symbol}
2524
end
2625

27-
Learn(train, model, (invars, outvars)::Pair) = Learn(train, StatsLearnModel(model, invars, outvars))
26+
Learn(train, model, (feats, targs)::Pair) = Learn(train, StatsLearnModel(model, feats, targs))
2827

2928
function Learn(train, lmodel::StatsLearnModel)
3029
if !Tables.istable(train)
@@ -33,22 +32,22 @@ function Learn(train, lmodel::StatsLearnModel)
3332

3433
cols = Tables.columns(train)
3534
names = Tables.columnnames(cols)
36-
invars = lmodel.invars(names)
37-
outvars = lmodel.outvars(names)
35+
feats = lmodel.feats(names)
36+
targs = lmodel.targs(names)
3837

39-
input = (; (var => Tables.getcolumn(cols, var) for var in invars)...)
40-
output = (; (var => Tables.getcolumn(cols, var) for var in outvars)...)
38+
input = (; (var => Tables.getcolumn(cols, var) for var in feats)...)
39+
output = (; (var => Tables.getcolumn(cols, var) for var in targs)...)
4140

4241
fmodel = fit(lmodel.model, input, output)
4342

44-
Learn(fmodel, invars)
43+
Learn(fmodel, feats)
4544
end
4645

4746
isrevertible(::Type{<:Learn}) = false
4847

4948
function applyfeat(transform::Learn, feat, prep)
5049
model = transform.model
51-
vars = transform.invars
50+
vars = transform.feats
5251

5352
cols = Tables.columns(feat)
5453
pairs = (var => Tables.getcolumn(cols, var) for var in vars)

test/runtests.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ const SLM = StatsLearnModels
2727
@testset "StatsLearnModel" begin
2828
# accessor functions
2929
model = DecisionTreeClassifier()
30-
invars = selector([:a, :b])
31-
outvars = selector(:c)
32-
lmodel = SLM.StatsLearnModel(model, invars, outvars)
30+
feats = selector([:a, :b])
31+
targs = selector(:c)
32+
lmodel = SLM.StatsLearnModel(model, feats, targs)
3333
@test lmodel.model === model
34-
@test lmodel.invars === invars
35-
@test lmodel.outvars === outvars
34+
@test lmodel.feats === feats
35+
@test lmodel.targs === targs
3636
# show
3737
lmodel = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
3838
@test sprint(show, lmodel) == """
@@ -106,17 +106,17 @@ const SLM = StatsLearnModels
106106
@testset "Learn" begin
107107
Random.seed!(123)
108108
outvar = :target
109-
invars = setdiff(propertynames(iris), [outvar])
110-
outvars = outvar
109+
feats = setdiff(propertynames(iris), [outvar])
110+
targs = outvar
111111
model = DecisionTreeClassifier()
112-
transform = Learn(iris[train, :], model, invars => outvars)
112+
transform = Learn(iris[train, :], model, feats => targs)
113113
@test !isrevertible(transform)
114114
pred = transform(iris[test, :])
115115
accuracy = count(pred.target .== iris.target[test]) / length(test)
116116
@test accuracy > 0.9
117117

118118
# throws
119119
# training data is not a table
120-
@test_throws ArgumentError Learn(nothing, model, invars => outvars)
120+
@test_throws ArgumentError Learn(nothing, model, feats => targs)
121121
end
122122
end

0 commit comments

Comments
 (0)