Skip to content

Commit bebbf55

Browse files
committed
make :evaluation a cv test to make repeatability test more sensitive
1 parent 3eb1e8e commit bebbf55

File tree

4 files changed

+25
-21
lines changed

4 files changed

+25
-21
lines changed

examples/bigtest/Manifest.toml

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
7878

7979
[[deps.BetaML]]
8080
deps = ["CategoricalArrays", "Combinatorics", "DelimitedFiles", "Distributions", "ForceImport", "LinearAlgebra", "MLJModelInterface", "PDMats", "Printf", "ProgressMeter", "Random", "Reexport", "StableRNGs", "Statistics", "Test", "Zygote"]
81-
git-tree-sha1 = "d80754995042cfa9233611f9b80dfa17483c42fe"
81+
git-tree-sha1 = "487007edd486b6be32c14f63637efdf0a957a38e"
8282
uuid = "024491cd-cc6b-443e-8034-08ea7eb7db2b"
83-
version = "0.6.0"
83+
version = "0.6.1"
8484

8585
[[deps.BinaryProvider]]
8686
deps = ["Libdl", "Logging", "SHA"]
@@ -143,9 +143,9 @@ version = "0.1.5"
143143

144144
[[deps.ChainRules]]
145145
deps = ["ChainRulesCore", "Compat", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics"]
146-
git-tree-sha1 = "e9023f88b1655ffc6a4aaef2502878e8116151ef"
146+
git-tree-sha1 = "34e265b1b0049896430625ce1638b2719c783c6b"
147147
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
148-
version = "1.35.1"
148+
version = "1.35.2"
149149

150150
[[deps.ChainRulesCore]]
151151
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
@@ -759,11 +759,11 @@ version = "0.18.2"
759759

760760
[[deps.MLJBase]]
761761
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
762-
git-tree-sha1 = "14740ef3eec192e0b48a601fbb810cb64f7c7dca"
762+
git-tree-sha1 = "6ddcfc397fc589114d6a9008d6c2311b818cec3e"
763763
repo-rev = "stack_cache_and_acceleration_rebased"
764764
repo-url = "https://github.com/JuliaAI/MLJBase.jl"
765765
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
766-
version = "0.20.4"
766+
version = "0.20.5"
767767

768768
[[deps.MLJClusteringInterface]]
769769
deps = ["Clustering", "Distances", "MLJModelInterface"]
@@ -850,7 +850,7 @@ uuid = "7fa162e1-0e29-41ca-a6fa-c000ca4e7e7e"
850850
version = "0.1.4"
851851

852852
[[deps.MLJTestIntegration]]
853-
deps = ["MLJ", "MLJTuning", "NearestNeighborModels", "Pkg", "Test"]
853+
deps = ["MLJ", "MLJBase", "MLJTuning", "NearestNeighborModels", "Pkg", "Test"]
854854
path = "/Users/anthony/MLJ/MLJTestIntegration"
855855
uuid = "697918b4-fdc1-4f9e-8ff9-929724cee270"
856856
version = "0.1.0"
@@ -936,10 +936,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c"
936936
version = "7.8.2"
937937

938938
[[deps.NNlib]]
939-
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
940-
git-tree-sha1 = "f89de462a7bc3243f95834e75751d70b3a33e59d"
939+
deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
940+
git-tree-sha1 = "a0331452b4cfd5e53ee2325376794aea47364d5a"
941941
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
942-
version = "0.8.5"
942+
version = "0.8.7"
943943

944944
[[deps.NNlibCUDA]]
945945
deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
@@ -1290,9 +1290,9 @@ version = "0.6.6"
12901290

12911291
[[deps.StaticArrays]]
12921292
deps = ["LinearAlgebra", "Random", "Statistics"]
1293-
git-tree-sha1 = "383a578bdf6e6721f480e749d503ebc8405a0b22"
1293+
git-tree-sha1 = "2bbd9f2e40afd197a1379aef05e0d85dba649951"
12941294
uuid = "90137ffa-7385-5640-81b9-e52037218182"
1295-
version = "1.4.6"
1295+
version = "1.4.7"
12961296

12971297
[[deps.StatisticalTraits]]
12981298
deps = ["ScientificTypesBase"]

examples/bigtest/notebook.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ known_problems = models() do model
6767
(name = "DecisionTreeClassifier", package_name="BetaML"),
6868
(name="RandomForestClassifier", package_name="BetaML"),
6969

70-
# https://github.com/sylvaticus/BetaML.jl/issues/32
71-
(name = "KernelPerceptronClassifier", package_name="BetaML"),
72-
7370
# https://github.com/alan-turing-institute/MLJ.jl/issues/939
7471
(name = "NuSVC", package_name="LIBSVM"),
7572
(name="SVMNuClassifier", package_name="ScikitLearn"),
@@ -98,3 +95,6 @@ fails2 |> DataFrame
9895
#-
9996

10097
report2 |> DataFrame
98+
99+
100+

src/attemptors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ function evaluation(measure, model, resources, data...; throw=false, verbosity=1
134134
es = map(resources) do resource
135135
evaluate(model, data...;
136136
measure=measure,
137-
resampling=Holdout(),
137+
resampling=CV(nfolds=4),
138138
acceleration=resource,
139139
verbosity=0)
140140
end

src/test.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ These additional tests are applied to `Supervised` models:
130130
131131
- `:evaluation`: Assuming MLJ is able to infer a suitable `measure`
132132
(metric), evaluate the performance of the model using `evaluate!`
133-
and a `Holdout` set.
133+
and and cross-validation.
134134
135135
- `:accelerated_evaluation`: Assuming the model appears to make
136136
repeatable predictions on retraining, repeat the `:evaluation` test
@@ -317,8 +317,9 @@ function test(model_proxies, data...; mod=Main, level=2, throw=false, verbosity=
317317
# determine computational resources:
318318
resources = MLJ.AbstractResource[CPU1(),] # fallback
319319
if level > 3
320-
per_fold = evaluation.per_fold[1]
321-
per_folds = map(1:(N_MODELS_FOR_REPEATABILITY_TEST - 1)) do i
320+
baseline = evaluation.per_fold[1]
321+
repeatable = true
322+
for i in 1:(N_MODELS_FOR_REPEATABILITY_TEST - 1)
322323
verbosity > 1 && print(
323324
"\rInternal repeatability tests, "*
324325
"$(i + 1) of $N_MODELS_FOR_REPEATABILITY_TEST trials complete"
@@ -332,10 +333,13 @@ function test(model_proxies, data...; mod=Main, level=2, throw=false, verbosity=
332333
verbosity=0,
333334
)
334335
o == "" || return nothing
335-
e.per_fold[1]
336+
if !(e.per_fold[1] baseline)
337+
repeatable = false
338+
break
339+
end
336340
end
337341
verbosity > 1 && print("")
338-
if all((per_fold), per_folds)
342+
if repeatable
339343
resources = RESOURCES
340344
verbosity > 1 && println(" Repeatable.")
341345
else

0 commit comments

Comments
 (0)