Skip to content

Commit 51ad597

Browse files
committed
change signature of convenience functions, with ignore now Bool
1 parent 4d9d0f0 commit 51ad597

File tree

5 files changed

+80
-49
lines changed

5 files changed

+80
-49
lines changed

examples/bigtest/notebook.jl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using DataFrames # for displaying tables
1515

1616
# # Regression
1717

18-
known_issues = models() do model
18+
known_problems = models() do model
1919
any([
2020
# https://github.com/lalvim/PartialLeastSquaresRegressor.jl/issues/29
2121
model.package_name == "PartialLeastSquaresRegressor",
@@ -25,25 +25,41 @@ known_issues = models() do model
2525
])
2626
end
2727

28-
MLJTestIntegration.test_single_target_regressors(ignore=known_issues, level=1)
29-
fails, summary =
30-
MLJTestIntegration.test_single_target_regressors(ignore=known_issues, level=3)
28+
MLJTestIntegration.test_single_target_regressors(
29+
known_problems,
30+
ignore=true,
31+
level=1
32+
)
33+
34+
fails, report =
35+
MLJTestIntegration.test_single_target_regressors(
36+
known_problems,
37+
ignore=true,
38+
level=3
39+
)
3140

3241
@test isempty(fails)
33-
summary |> DataFrame
42+
report |> DataFrame
3443

3544

3645
# # Classification
3746

3847
# https://github.com/alan-turing-institute/MLJ.jl/issues/939
39-
known_issues = [
48+
known_problems = [
4049
(name = "DecisionTreeClassifier", package_name="BetaML"),
50+
(name = "PerceptronClassifier", package_name="BetaML"),
4151
(name = "NuSVC", package_name="LIBSVM"),
4252
(name="PegasosClassifier", package_name="BetaML"),
4353
(name="RandomForestClassifier", package_name="BetaML"),
4454
(name="SVMNuClassifier", package_name="ScikitLearn"),
4555
]
4656

47-
MLJTestIntegration.test_single_target_classifiers(ignore=known_issues, level=1)
48-
fails, summary =
49-
MLJTestIntegration.test_single_target_classifiers(ignore=known_issues, level=3)
57+
MLJTestIntegration.test_single_target_classifiers(
58+
known_problems,
59+
level=1
60+
)
61+
fails, report =
62+
MLJTestIntegration.test_single_target_classifiers(
63+
known_problems,
64+
level=3,
65+
)

src/special_cases.jl

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,31 @@
11
# # HELPERS
22

3-
_strip(proxy) = (name=proxy.name, package_name=proxy.package_name)
4-
5-
function _filter(proxies, bad)
6-
sbad = _strip.(bad)
7-
filter(proxies) do proxy
8-
!(_strip(proxy) in sbad)
9-
end
3+
function warn_not_testing_these(models)
4+
"Not testing the following models, as incompatible with testing data:\n"*
5+
"$models"
106
end
117

12-
# fallback:
13-
function _test(data, ignore; kwargs...)
14-
proxies = _filter(models(matching(data...)), ignore)
15-
test(proxies, data...; kwargs...)
8+
strip(proxy) = (name=proxy.name, package_name=proxy.package_name)
9+
10+
function actual_proxies(raw_proxies, data, ignore, verbosity)
11+
proxies = strip.(raw_proxies)
12+
from_registry = strip.(models(matching(data...)))
13+
if ignore
14+
actual_proxies = setdiff(from_registry, proxies)
15+
else
16+
actual_proxies = intersect(proxies, from_registry)
17+
rejected = setdiff(proxies, actual_proxies)
18+
if !isempty(rejected) && verbosity > 0
19+
@warn warn_not_testing_these(rejected)
20+
end
21+
end
22+
return actual_proxies
1623
end
1724

18-
# when there are no models to exclude:
19-
function _test(data, ignore::Nothing; kwargs...)
20-
proxies = models(matching(data...))
21-
test(proxies, data...; kwargs...)
25+
function _test(proxies, data; ignore::Bool=false, verbosity=1, kwargs...)
26+
test(actual_proxies(proxies, data, ignore, verbosity), data...; kwargs...)
2227
end
28+
_test(data; ignore=true, kwargs...) = _test([], data; ignore, kwargs...)
2329

2430

2531
# # SINGLE TARGET CLASSIFICATION
@@ -31,8 +37,8 @@ function _make_binary()
3137
return X, y
3238
end
3339

34-
test_single_target_classifiers(; ignore=nothing, kwargs...) =
35-
_test(_make_binary(), ignore; kwargs...)
40+
test_single_target_classifiers(args...; kwargs...) =
41+
_test(args..., _make_binary(); kwargs...)
3642

3743

3844
# # SINGLE TARGET REGRESSION
@@ -43,8 +49,8 @@ function _make_baby_boston()
4349
return X, y
4450
end
4551

46-
test_single_target_regressors(; ignore=nothing, kwargs...) =
47-
_test(_make_baby_boston(), ignore; kwargs...)
52+
test_single_target_regressors(args...; kwargs...) =
53+
_test(args..., _make_baby_boston(); kwargs...)
4854

4955

5056
# # SINGLE TARGET COUNT REGRESSORS
@@ -55,13 +61,13 @@ function _make_count()
5561
return X, y
5662
end
5763

58-
test_single_target_count_regressors(; ignore=nothing, kwargs...) =
59-
_test(_make_count(), ignore; kwargs...)
64+
test_single_target_count_regressors(args...; kwargs...) =
65+
_test(args..., _make_count(); kwargs...)
6066

6167

6268
# # CONTINUOUS TABLE TRANSFORMERS
6369

6470
_make_transformer() = (first(_make_baby_boston()),)
6571

66-
test_continuous_table_transformers(; ingore=nothing, kwargs...) =
67-
_test(_make_transformer(), ignore; kwargs...)
72+
test_continuous_table_transformers(args...; kwargs...) =
73+
_test(args..., _make_transformer(); kwargs...)

test/runtests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ using Pkg
33
using MLJTestIntegration
44
using MLJTestIntegration.MLJ
55
using MLJTestIntegration.MLJ.MLJModels
6+
const MTI = MLJTestIntegration
67

78
# enable conditional testing of modules by providing test_args
89
# e.g. `Pkg.test("MLJBase", test_args=["misc"])`
9-
RUN_ALL_TESTS = isempty(ARGS)
10+
11+
const RUN_ALL_TESTS = isempty(ARGS)
1012
macro conditional_testset(name, expr)
1113
name = string(name)
1214
esc(quote

test/special_cases.jl

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
1-
@testset "_filter" begin
2-
proxies = [
3-
(name="1", package_name="A", extra="cat"),
4-
(name="1", package_name="B", extra="mouse"),
5-
(name="2", package_name="B", extra="dog"),
6-
(name="1", package_name="C", extra="rat"),
7-
]
1+
classifiers = [
2+
(name = "ConstantClassifier", package_name = "MLJModels"),
3+
(name = "DeterministicConstantClassifier", package_name = "MLJModels")
4+
]
85

9-
bad = [
10-
(name="1", package_name="A"),
11-
(name="1", package_name="B"),
12-
]
6+
regressors = [
7+
(name = "ConstantRegressor", package_name = "MLJModels"),
8+
(name = "DeterministicConstantRegressor", package_name = "MLJModels")
9+
]
1310

14-
@test MLJTestIntegration._filter(proxies, bad) == [
15-
(name="2", package_name="B", extra="dog"),
16-
(name="1", package_name="C", extra="rat"),
17-
]
11+
@testset "actual_proxies" begin
12+
data = MTI._make_baby_boston()
13+
proxies = @test_logs MTI.actual_proxies(regressors, data, false, 1)
14+
@test proxies == regressors
15+
proxies2 = @test_logs MTI.actual_proxies(regressors, data, true, 1)
16+
@test proxies2 == setdiff(MTI.strip.(models(matching(data...))), regressors)
17+
proxies = @test_logs(
18+
(:warn, MTI.warn_not_testing_these(classifiers)),
19+
MTI.actual_proxies(vcat(regressors, classifiers), data, false, 1),
20+
)
21+
@test proxies == regressors
22+
proxies = @test_logs(
23+
MTI.actual_proxies(vcat(regressors, classifiers), data, true, 1),
24+
)
25+
@test proxies == proxies2
1826
end

test/test.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# grab some classifiers from MLJModels:
21
classifiers = [
32
(name = "ConstantClassifier", package_name = "MLJModels"),
43
(name = "DeterministicConstantClassifier", package_name = "MLJModels")

0 commit comments

Comments
 (0)