Skip to content

Commit 7849483

Browse files
authored
Redesign the 'Learn' transform syntax (#3)
1 parent ed2d8f9 commit 7849483

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

src/learn.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# ------------------------------------------------------------------
44

55
"""
6-
Learn(model, train, incols, outcols)
6+
Learn(train, model, incols => outcols)
77
88
Fits the statistical learning `model` using the input columns, selected by `incols`,
99
and the output columns, selected by `outcols`, from the `train` table.
@@ -14,19 +14,19 @@ a collection of identifiers or a regular expression (regex).
1414
# Examples
1515
1616
```julia
17-
Learn(model, train, [1, 2, 3], "d")
18-
Learn(model, train, [:a, :b, :c], :d)
19-
Learn(model, train, ["a", "b", "c"], 4)
20-
Learn(model, train, [1, 2, 3], [:d, :e])
21-
Learn(model, train, r"[abc]", ["d", "e"])
17+
Learn(train, model, [1, 2, 3] => "d")
18+
Learn(train, model, [:a, :b, :c] => :d)
19+
Learn(train, model, ["a", "b", "c"] => 4)
20+
Learn(train, model, [1, 2, 3] => [:d, :e])
21+
Learn(train, model, r"[abc]" => ["d", "e"])
2222
```
2323
"""
2424
struct Learn{M<:FittedModel} <: StatelessFeatureTransform
2525
model::M
2626
input::Vector{Symbol}
2727
end
2828

29-
function Learn(model, train, incols, outcols)
29+
function Learn(train, model, (incols, outcols)::Pair)
3030
if !Tables.istable(train)
3131
throw(ArgumentError("training data must be a table"))
3232
end

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ const SLM = StatsLearnModels
1313
iris = DataFrame(load_iris())
1414
input = iris[:, Not(:target)]
1515
output = iris[:, [:target]]
16-
Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
1716
train, test = partition(1:nrow(input), 0.7, rng=123)
17+
Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
1818
fmodel = SLM.fit(Tree(), input[train, :], output[train, :])
1919
pred = SLM.predict(fmodel, input[test, :])
2020
accuracy = count(pred.target .== output.target[test]) / length(test)
@@ -26,16 +26,16 @@ const SLM = StatsLearnModels
2626
iris = DataFrame(load_iris())
2727
outcol = :target
2828
incols = setdiff(propertynames(iris), [outcol])
29-
Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
3029
train, test = partition(1:nrow(iris), 0.7, rng=123)
31-
transform = Learn(Tree(), iris[train, :], incols, outcol)
30+
Tree = @load(DecisionTreeClassifier, pkg=DecisionTree, verbosity=0)
31+
transform = Learn(iris[train, :], Tree(), incols => outcol)
3232
@test !isrevertible(transform)
3333
pred = transform(iris[test, :])
3434
accuracy = count(pred.target .== iris.target[test]) / length(test)
3535
@test accuracy > 0.9
3636

3737
# throws
3838
# training data is not a table
39-
@test_throws ArgumentError Learn(Tree(), nothing, incols, outcol)
39+
@test_throws ArgumentError Learn(nothing, Tree(), incols => outcol)
4040
end
4141
end

0 commit comments

Comments
 (0)