Skip to content

Commit 7655ca6

Browse files
authored
Add NearestNeighbors.jl models (#6)
* Add NearestNeighbors.jl models * Fix tests
1 parent bf2674e commit 7655ca6

File tree

6 files changed

+131
-16
lines changed

6 files changed

+131
-16
lines changed

Project.toml

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,13 @@ version = "0.2.1"
55

66
[deps]
77
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
8+
DataScienceTraits = "6cb2f572-2d2b-4ba6-bdb3-e710fa044d6c"
89
DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
10+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
911
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1012
GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
13+
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
14+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1115
TableTransforms = "0d432bfd-3ee1-4ac1-886a-39f05cc69a3e"
1216
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1317

@@ -19,10 +23,14 @@ StatsLearnModelsMLJModelInterfaceExt = "MLJModelInterface"
1923

2024
[compat]
2125
ColumnSelectors = "0.1"
26+
DataScienceTraits = "0.1"
2227
DecisionTree = "0.12"
28+
Distances = "0.10"
2329
Distributions = "0.25"
2430
GLM = "1.9"
2531
MLJModelInterface = "1.9"
32+
NearestNeighbors = "0.4"
33+
StatsBase = "0.34"
2634
TableTransforms = "1.15"
2735
Tables = "1.11"
2836
julia = "1.9"

src/StatsLearnModels.jl

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,19 +5,28 @@
55
module StatsLearnModels
66

77
using Tables
8+
using Distances
9+
using DataScienceTraits
10+
using StatsBase: mode, mean
811
using ColumnSelectors: selector
912
using TableTransforms: StatelessFeatureTransform
13+
14+
import DataScienceTraits as DST
1015
import TableTransforms: applyfeat, isrevertible
1116

12-
import GLM
13-
import DecisionTree as DT
1417
using DecisionTree: AdaBoostStumpClassifier, DecisionTreeClassifier, RandomForestClassifier
1518
using DecisionTree: DecisionTreeRegressor, RandomForestRegressor
1619
using Distributions: UnivariateDistribution
20+
using NearestNeighbors: MinkowskiMetric
21+
22+
import GLM
23+
import DecisionTree as DT
24+
import NearestNeighbors as NN
1725

1826
include("interface.jl")
19-
include("models/decisiontree.jl")
2027
include("models/glm.jl")
28+
include("models/decisiontree.jl")
29+
include("models/nearestneighbors.jl")
2130
include("learn.jl")
2231

2332
export
@@ -32,6 +41,10 @@ export
3241
LinearRegressor,
3342
GeneralizedLinearRegressor,
3443

44+
# NearestNeighbors.jl
45+
KNNClassifier,
46+
KNNRegressor,
47+
3548
# transform
3649
Learn
3750

src/models/decisiontree.jl

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,24 +1,28 @@
1-
const DTModel = Union{
1+
# ------------------------------------------------------------------
2+
# Licensed under the MIT License. See LICENSE in the project root.
3+
# ------------------------------------------------------------------
4+
5+
const DecisionTreeModel = Union{
26
AdaBoostStumpClassifier,
37
DecisionTreeClassifier,
48
RandomForestClassifier,
59
DecisionTreeRegressor,
610
RandomForestRegressor
711
}
812

9-
function fit(model::DTModel, input, output)
13+
function fit(model::DecisionTreeModel, input, output)
1014
cols = Tables.columns(output)
1115
names = Tables.columnnames(cols)
12-
outcol = first(names)
13-
y = Tables.getcolumn(cols, outcol)
16+
outnm = first(names)
17+
y = Tables.getcolumn(cols, outnm)
1418
X = Tables.matrix(input)
1519
DT.fit!(model, X, y)
16-
FittedModel(model, outcol)
20+
FittedModel(model, outnm)
1721
end
1822

19-
function predict(fmodel::FittedModel{<:DTModel}, table)
20-
outcol = fmodel.cache
23+
function predict(fmodel::FittedModel{<:DecisionTreeModel}, table)
24+
outnm = fmodel.cache
2125
X = Tables.matrix(table)
2226
ŷ = DT.predict(fmodel.model, X)
23-
(; outcol => ŷ) |> Tables.materializer(table)
27+
(; outnm => ŷ) |> Tables.materializer(table)
2428
end

src/models/glm.jl

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,7 @@
1+
# ------------------------------------------------------------------
2+
# Licensed under the MIT License. See LICENSE in the project root.
3+
# ------------------------------------------------------------------
4+
15
abstract type GLMModel end
26

37
struct LinearRegressor{K} <: GLMModel
@@ -18,18 +22,18 @@ GeneralizedLinearRegressor(dist::UnivariateDistribution, link=nothing; kwargs...
1822
function fit(model::GLMModel, input, output)
1923
cols = Tables.columns(output)
2024
names = Tables.columnnames(cols)
21-
outcol = first(names)
25+
outnm = first(names)
2226
X = Tables.matrix(input)
23-
y = Tables.getcolumn(cols, outcol)
27+
y = Tables.getcolumn(cols, outnm)
2428
fitted = _fit(model, X, y)
25-
FittedModel(model, (fitted, outcol))
29+
FittedModel(model, (fitted, outnm))
2630
end
2731

2832
function predict(fmodel::FittedModel{<:GLMModel}, table)
29-
model, outcol = fmodel.cache
33+
model, outnm = fmodel.cache
3034
X = Tables.matrix(table)
3135
ŷ = GLM.predict(model, X)
32-
(; outcol => ŷ) |> Tables.materializer(table)
36+
(; outnm => ŷ) |> Tables.materializer(table)
3337
end
3438

3539
_fit(model::LinearRegressor, X, y) = GLM.lm(X, y; model.kwargs...)

src/models/nearestneighbors.jl

Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,63 @@
1+
# ------------------------------------------------------------------
2+
# Licensed under the MIT License. See LICENSE in the project root.
3+
# ------------------------------------------------------------------
4+
5+
abstract type NearestNeighborsModel end
6+
7+
struct KNNClassifier{M<:Metric} <: NearestNeighborsModel
8+
k::Int
9+
metric::M
10+
leafsize::Int
11+
reorder::Bool
12+
end
13+
14+
KNNClassifier(k, metric=Euclidean(); leafsize=10, reorder=true) = KNNClassifier(k, metric, leafsize, reorder)
15+
16+
struct KNNRegressor{M<:Metric} <: NearestNeighborsModel
17+
k::Int
18+
metric::M
19+
leafsize::Int
20+
reorder::Bool
21+
end
22+
23+
KNNRegressor(k, metric=Euclidean(); leafsize=10, reorder=true) = KNNRegressor(k, metric, leafsize, reorder)
24+
25+
function fit(model::NearestNeighborsModel, input, output)
26+
cols = Tables.columns(output)
27+
outnm = Tables.columnnames(cols) |> first
28+
outcol = Tables.getcolumn(cols, outnm)
29+
_checkoutput(model, outcol)
30+
(; metric, leafsize, reorder) = model
31+
data = Tables.matrix(input, transpose=true)
32+
tree = if metric isa MinkowskiMetric
33+
NN.KDTree(data, metric; leafsize, reorder)
34+
else
35+
NN.BallTree(data, metric; leafsize, reorder)
36+
end
37+
FittedModel(model, (tree, outnm, outcol))
38+
end
39+
40+
function predict(fmodel::FittedModel{<:NearestNeighborsModel}, table)
41+
(; model, cache) = fmodel
42+
tree, outnm, outcol = cache
43+
data = Tables.matrix(table, transpose=true)
44+
indvec, _ = NN.knn(tree, data, model.k)
45+
aggfun = _aggfun(model)
46+
ŷ = [aggfun(outcol[inds]) for inds in indvec]
47+
(; outnm => ŷ) |> Tables.materializer(table)
48+
end
49+
50+
function _checkoutput(::KNNClassifier, x)
51+
if !(elscitype(x) <: DST.Categorical)
52+
throw(ArgumentError("output column must be categorical"))
53+
end
54+
end
55+
56+
function _checkoutput(::KNNRegressor, x)
57+
if !(elscitype(x) <: DST.Continuous)
58+
throw(ArgumentError("output column must be continuous"))
59+
end
60+
end
61+
62+
_aggfun(::KNNClassifier) = mode
63+
_aggfun(::KNNRegressor) = mean

test/runtests.jl

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,29 @@ const SLM = StatsLearnModels
3636
@test accuracy > 0.9
3737
end
3838

39+
@testset "NearestNeighbors" begin
40+
Random.seed!(123)
41+
model = KNNClassifier(5)
42+
fmodel = SLM.fit(model, input[train, :], output[train, :])
43+
pred = SLM.predict(fmodel, input[test, :])
44+
accuracy = count(pred.target .== output.target[test]) / length(test)
45+
@test accuracy > 0.9
46+
47+
Random.seed!(123)
48+
a = rand(1:0.1:10, 100)
49+
b = rand(1:0.1:10, 100)
50+
y = 2a + b
51+
input = DataFrame(; a, b)
52+
output = DataFrame(; y)
53+
model = KNNRegressor(5)
54+
fmodel = SLM.fit(model, input, output)
55+
pred = SLM.predict(fmodel, input)
56+
@test count(isapprox.(pred.y, y, atol=0.8)) > 80
57+
58+
@test_throws ArgumentError SLM.fit(KNNClassifier(5), input, output)
59+
@test_throws ArgumentError SLM.fit(KNNRegressor(5), input, rand('a':'z', 100))
60+
end
61+
3962
@testset "GLM" begin
4063
x = [1, 2, 3]
4164
y = [2, 4, 7]

0 commit comments

Comments (0)