Skip to content

Commit fe0fda9

Browse files
authored
Add accessor functions (#10)
1 parent 507684f commit fe0fda9

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

src/interface.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,24 @@ function Base.show(io::IO, model::StatsLearnModel{M}) where {M}
5050
println(io, "├─ input: $(model.input)")
5151
print(io, "└─ output: $(model.output)")
5252
end
53+
54+
"""
55+
StatsLearnModels.model(lmodel::StatsLearnModel)
56+
57+
Returns the model of the `lmodel`.
58+
"""
59+
model(lmodel::StatsLearnModel) = lmodel.model
60+
61+
"""
62+
StatsLearnModels.input(lmodel::StatsLearnModel)
63+
64+
Returns the input column selection of the `lmodel`.
65+
"""
66+
input(lmodel::StatsLearnModel) = lmodel.input
67+
68+
"""
69+
StatsLearnModels.output(lmodel::StatsLearnModel)
70+
71+
Returns the output column selection of the `lmodel`.
72+
"""
73+
output(lmodel::StatsLearnModel) = lmodel.output

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
23
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
34
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
45
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"

test/runtests.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Test
66

77
using GLM: ProbitLink
88
using Distributions: Binomial
9+
using ColumnSelectors: selector
910

1011
import MLJ, MLJDecisionTreeInterface
1112

@@ -24,8 +25,17 @@ const SLM = StatsLearnModels
2425
end
2526

2627
@testset "StatsLearnModel" begin
27-
model = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
28-
@test sprint(show, model) == """
28+
# accessor functions
29+
model = DecisionTreeClassifier()
30+
incols = selector([:a, :b])
31+
outcols = selector(:c)
32+
lmodel = SLM.StatsLearnModel(model, incols, outcols)
33+
@test SLM.model(lmodel) === model
34+
@test SLM.input(lmodel) === incols
35+
@test SLM.output(lmodel) === outcols
36+
# show
37+
lmodel = SLM.StatsLearnModel(DecisionTreeClassifier(), [:a, :b], :c)
38+
@test sprint(show, lmodel) == """
2939
StatsLearnModel{DecisionTreeClassifier}
3040
├─ input: [:a, :b]
3141
└─ output: :c"""

0 commit comments

Comments
 (0)