Skip to content

Commit 9825498

Browse files
committed
Initial commit
1 parent 4af79ae commit 9825498

File tree

6 files changed

+67
-38
lines changed

6 files changed

+67
-38
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
*.jl.*.cov
22
*.jl.cov
33
*.jl.mem
4-
/Manifest.toml
4+
Manifest.toml
5+
.vscode

Project.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
name = "StatsLearnModels"
22
uuid = "c146b59d-1589-421c-8e09-a22e554fd05c"
33
authors = ["Elias Carvalho <[email protected]> and contributors"]
4-
version = "1.0.0-DEV"
4+
version = "0.1.0"
5+
6+
[deps]
7+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
8+
9+
[weakdeps]
10+
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
11+
12+
[extensions]
13+
StatsLearnModelsMLJModelInterfaceExt = "MLJModelInterface"
514

615
[compat]
16+
MLJModelInterface = "1.9"
17+
Tables = "1.11"
718
julia = "1.9"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# StatsLearnModels
1+
# [WIP] StatsLearnModels.jl
22

33
[![Build Status](https://github.com/JuliaML/StatsLearnModels.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/JuliaML/StatsLearnModels.jl/actions/workflows/CI.yml?query=branch%3Amain)
44
[![Coverage](https://codecov.io/gh/JuliaML/StatsLearnModels.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/JuliaML/StatsLearnModels.jl)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
module StatsLearnModelsMLJModelInterfaceExt
2+
3+
using Tables
4+
import StatsLearnModels as SLM
5+
import MLJModelInterface as MI
6+
7+
isprobabilistic(model::MI.Model) = MI.prediction_type(model) == :probabilistic
8+
isprobabilistic(model::MI.Probabilistic) = true
9+
10+
function SLM.fit(model::MI.Model, input, output)
11+
cols = Tables.columns(output)
12+
names = Tables.columnnames(cols)
13+
y = Tables.getcolumn(cols, first(names))
14+
data = MI.reformat(model, input, y)
15+
fitresult, _... = MI.fit(model, 0, data...)
16+
SLM.FittedModel(model, fitresult)
17+
end
18+
19+
function SLM.predict(fmodel::SLM.FittedModel{<:MI.Model}, table)
20+
(; model, fitresult) = fmodel
21+
data = MI.reformat(model, table)
22+
if isprobabilistic(model)
23+
MI.predict_mode(model, fitresult, data...)
24+
else
25+
MI.predict(model, fitresult, data...)
26+
end
27+
end
28+
29+
end

src/StatsLearnModels.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,27 @@
11
module StatsLearnModels
22

3-
# Write your package code here.
3+
"""
4+
StatsLearnModels.fit(model, input, output) -> FittedModel
5+
6+
TODO
7+
"""
8+
function fit end
9+
10+
"""
11+
StatsLearnModels.predict(model::FittedModel, table)
12+
13+
TODO
14+
"""
15+
function predict end
16+
17+
"""
18+
StatsLearnModels.FittedModel(model, fitresult)
19+
20+
TODO
21+
"""
22+
struct FittedModel{M,F}
23+
model::M
24+
fitresult::F
25+
end
426

527
end

test/Manifest.toml

Lines changed: 0 additions & 34 deletions
This file was deleted.

0 commit comments

Comments
 (0)