Skip to content

Commit 8a260f7

Browse files
committed
add stack_evaluation; needs JuliaAI/MLJBase.jl#767
add stack_evaluation; needs JuliaAI/MLJBase.jl#767 rm target_scitype arg from stack_evaluation put stack test into test() oops fix some bugs separate out :accelerated_stack_evaluation test more tweaks oops
1 parent 81eae18 commit 8a260f7

File tree

9 files changed

+291
-89
lines changed

9 files changed

+291
-89
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.1.0"
66
[deps]
77
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
88
MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"
9+
NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"
910
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1011
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1112

examples/bigtest/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
44
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
55
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
66
LightGBM = "7acf609c-83a4-11e9-1ffb-b912bcd3b04a"
7+
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
78
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
89
MLJClusteringInterface = "d354fa79-ed1c-40d4-88ef-b8c7bd1568af"
910
MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661"

examples/bigtest/notebook.jl

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using DataFrames # for displaying tables
1616
# # Regression
1717

1818
known_problems = models() do model
19+
!model.is_pure_julia ||
1920
any([
2021
# https://github.com/lalvim/PartialLeastSquaresRegressor.jl/issues/29
2122
model.package_name == "PartialLeastSquaresRegressor",
@@ -28,14 +29,15 @@ end
2829
MLJTestIntegration.test_single_target_regressors(
2930
known_problems,
3031
ignore=true,
31-
level=1
32+
level=1,
3233
)
3334

3435
fails1, report1 =
3536
MLJTestIntegration.test_single_target_regressors(
3637
known_problems,
3738
ignore=true,
38-
level=4
39+
level=4,
40+
verbosity=2,
3941
)
4042

4143
fails1 |> DataFrame
@@ -47,31 +49,42 @@ report1 |> DataFrame
4749

4850
# # Classification
4951

50-
# https://github.com/alan-turing-institute/MLJ.jl/issues/939
51-
known_problems = [
52-
(name = "KernelPerceptronClassifier", package_name="BetaML"),
53-
(name = "DecisionTreeClassifier", package_name="BetaML"),
54-
(name = "PerceptronClassifier", package_name="BetaML"),
55-
(name = "NuSVC", package_name="LIBSVM"),
56-
(name="PegasosClassifier", package_name="BetaML"),
57-
(name="RandomForestClassifier", package_name="BetaML"),
58-
(name="SVMNuClassifier", package_name="ScikitLearn"),
59-
(name="KernelPerceptronClassifier", package_name="BetaML"),
60-
(name="LinearSVC", package_name="LIBSVM"),
61-
(name= "MultinomialClassifier", "MLJLinearModels"),
62-
(name="SVMLinearClassifier", package_name="ScikitLearn"),
63-
]
52+
known_problems = models() do model
53+
!model.is_pure_julia ||
54+
(name = model.name, package_name = model.package_name) in
55+
[
56+
# https://github.com/JuliaAI/MLJMultivariateStatsInterface.jl/issues/41
57+
(name = "LDA", package_name = "MultivariateStats"),
58+
(name = "SubspaceLDA", package_name = "MultivariateStats"),
59+
(name = "BayesianLDA", package_name = "MultivariateStats"),
60+
(name = "BayesianSubspaceLDA", package_name = "MultivariateStats"),
61+
62+
# https://github.com/alan-turing-institute/MLJ.jl/issues/939
63+
(name = "DecisionTreeClassifier", package_name="BetaML"),
64+
(name = "PerceptronClassifier", package_name="BetaML"),
65+
(name = "NuSVC", package_name="LIBSVM"),
66+
(name="PegasosClassifier", package_name="BetaML"),
67+
(name="RandomForestClassifier", package_name="BetaML"),
68+
(name="SVMNuClassifier", package_name="ScikitLearn"),
69+
(name="KernelPerceptronClassifier", package_name="BetaML"),
70+
(name="LinearSVC", package_name="LIBSVM"),
71+
(name= "MultinomialClassifier", package_name="MLJLinearModels"),
72+
(name="SVMLinearClassifier", package_name="ScikitLearn"),
73+
]
74+
end
6475

6576
MLJTestIntegration.test_single_target_classifiers(
6677
known_problems,
6778
level=1,
6879
ignore=true,
6980
)
81+
7082
fails2, report2 =
7183
MLJTestIntegration.test_single_target_classifiers(
7284
known_problems,
7385
ignore=true,
7486
level=4,
87+
verbosity=2
7588
)
7689

7790
fails2 |> DataFrame

src/MLJTestIntegration.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
module MLJTestIntegration
22

3-
const N_MODELS_FOR_REPEATABILITY_TEST = 3
3+
const N_MODELS_FOR_REPEATABILITY_TEST = 50
44

55
using MLJ
66
using Pkg
77
using .Threads
88
using Test
9+
using NearestNeighborModels
910

1011
include("attemptors.jl")
1112
include("test.jl")
@@ -14,7 +15,6 @@ include("dummy_model.jl")
1415

1516
function __init__()
1617
global RESOURCES = (CPU1(), CPUThreads())
17-
@info "Testing with $(nthreads()) threads. "
1818
end
1919

2020
using .DummyModel

src/attemptors.jl

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ function threshold_prediction(model, data...; throw=false, verbosity=1)
127127
end
128128

129129
function evaluation(measure, model, resources, data...; throw=false, verbosity=1)
130-
message = "[:evaluation] Evaluating performance "
130+
L = length(resources)
131+
message = L > 1 ? "[:accelerated_evaluation] " : "[evaluation] "
132+
message *= "Evaluating model performance using with $L resources. "
131133
attempt(finalize(message, verbosity); throw) do
132134
es = map(resources) do accel
133135
evaluate(model, data...;
@@ -136,7 +138,7 @@ function evaluation(measure, model, resources, data...; throw=false, verbosity=1
136138
acceleration=accel,
137139
verbosity=0)
138140
end
139-
ms = map(e->e.measurement, es)
141+
ms = map(e->sort(e.per_fold[1]), es)
140142
m = first(ms)
141143
@assert all(≈(m), collect(ms)[2:end]) ERR_INCONSISTENT_RESULTS
142144
return first(es)
@@ -177,6 +179,7 @@ function ensemble_prediction(model, data...; throw=false, verbosity=1)
177179
end
178180
end
179181

182+
# the `model` must support iteration (`!isnothing(iteration_parameter(model))`)
180183
function iteration_prediction(measure, model, data...; throw=false, verbosity=1)
181184
message = "[:iteration_prediction] Iterating with controls "
182185
attempt(finalize(message, verbosity); throw) do
@@ -190,3 +193,59 @@ function iteration_prediction(measure, model, data...; throw=false, verbosity=1)
190193
predict(mach, first(data))
191194
end
192195
end
196+
197+
# the `model` can only be a single-target deterministic regressor or
198+
# probabilistic classifier.
199+
function stack_evaluation(
200+
model,
201+
resources,
202+
data...;
203+
throw=false,
204+
verbosity=1
205+
)
206+
L = length(resources)
207+
message = L > 1 ? "[:accelerated_stack_evaluation] " : "[stack_evaluation] "
208+
message *= "Evaluating a stack containing model "*
209+
"with $L resources. "
210+
target_scitype = MLJ.target_scitype(model)
211+
if AbstractVector{Continuous} <: target_scitype
212+
models = (knn1=KNNRegressor(K=4),
213+
knn2=KNNRegressor(K=6),
214+
model=model)
215+
metalearner = KNNRegressor()
216+
measure = LPLoss(2)
217+
else
218+
models = (knn1=KNNClassifier(K=4),
219+
knn2=KNNClassifier(K=6),
220+
model=model)
221+
metalearner = KNNClassifier()
222+
measure = BrierScore()
223+
# models = (tree=DecisionTreeClassifier(),
224+
# knn=KNNClassifier(K=6),
225+
# model=model)
226+
# metalearner = KNNClassifier()
227+
# measure = BrierScore()
228+
end
229+
attempt(finalize(message, verbosity); throw) do
230+
es = map(resources) do accel
231+
mystack = Stack(
232+
; metalearner,
233+
resampling=CV(;nfolds=3),
234+
acceleration=accel,
235+
models...)
236+
237+
evaluate(
238+
mystack,
239+
data...;
240+
measure=measure,
241+
resampling=Holdout(),
242+
verbosity=0,
243+
)
244+
end |> collect
245+
ms = map(e->sort(e.per_fold[1]), es)
246+
m = first(ms)
247+
# @show ms
248+
@assert all(≈(m), ms[2:end]) ERR_INCONSISTENT_RESULTS
249+
first(es)
250+
end
251+
end

src/special_cases.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ strip(proxy) = (name=proxy.name, package_name=proxy.package_name)
1818
function actual_proxies(raw_proxies, data, ignore, verbosity)
1919
if !(raw_proxies isa Vector)
2020
raw_proxies = [raw_proxies, ]
21-
end
21+
end
2222
proxies = strip.(raw_proxies)
2323
from_registry = strip.(models(matching(data...)))
2424
if ignore
@@ -34,7 +34,7 @@ function actual_proxies(raw_proxies, data, ignore, verbosity)
3434
end
3535

3636
function _test(proxies, data; ignore::Bool=false, verbosity=1, kwargs...)
37-
test(actual_proxies(proxies, data, ignore, verbosity), data...; kwargs...)
37+
test(actual_proxies(proxies, data, ignore, verbosity), data...; verbosity, kwargs...)
3838
end
3939
_test(data; ignore=true, kwargs...) = _test([], data; ignore, kwargs...)
4040

0 commit comments

Comments
 (0)