Skip to content

Commit 853591a

Browse files
committed
Add docstrings
1 parent 9825498 commit 853591a

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

ext/StatsLearnModelsMLJModelInterfaceExt.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,23 @@ isprobabilistic(model::MI.Probabilistic) = true
1010
function SLM.fit(model::MI.Model, input, output)
1111
cols = Tables.columns(output)
1212
names = Tables.columnnames(cols)
13-
y = Tables.getcolumn(cols, first(names))
13+
target = first(names)
14+
y = Tables.getcolumn(cols, target)
1415
data = MI.reformat(model, input, y)
1516
fitresult, _... = MI.fit(model, 0, data...)
16-
SLM.FittedModel(model, fitresult)
17+
SLM.FittedModel(model, (fitresult, target))
1718
end
1819

1920
function SLM.predict(fmodel::SLM.FittedModel{<:MI.Model}, table)
20-
(; model, fitresult) = fmodel
21+
(; model, cache) = fmodel
22+
fitresult, target = cache
2123
data = MI.reformat(model, table)
22-
if isprobabilistic(model)
24+
= if isprobabilistic(model)
2325
MI.predict_mode(model, fitresult, data...)
2426
else
2527
MI.predict(model, fitresult, data...)
2628
end
29+
(; target => ŷ)
2730
end
2831

2932
end

src/StatsLearnModels.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,29 @@ module StatsLearnModels
33
"""
44
StatsLearnModels.fit(model, input, output) -> FittedModel
55
6-
TODO
6+
Fit statistical learning `model` using features in `input` table
7+
and targets in `output` table. Returns a fitted model with all
8+
the necessary information for prediction with the `predict` function.
79
"""
810
function fit end
911

1012
"""
1113
StatsLearnModels.predict(model::FittedModel, table)
1214
13-
TODO
15+
Predict the target values using the fitted statistical learning `model`
16+
and a new `table` of features.
1417
"""
1518
function predict end
1619

1720
"""
18-
StatsLearnModels.FittedModel(model, fitresult)
21+
StatsLearnModels.FittedModel(model, cache)
1922
20-
TODO
23+
Wrapper type used to save learning model and auxiliary
24+
variables needed for prediction.
2125
"""
22-
struct FittedModel{M,F}
26+
struct FittedModel{M,C}
2327
model::M
24-
fitresult::F
28+
cache::C
2529
end
2630

2731
end

0 commit comments

Comments
 (0)