Skip to content

Commit 94341af

Browse files
authored
Merge pull request #20 from JuliaAI/restrict-ensemble-testing
Restrict ensemble testing to exclude models with `Count` targets; expose dataset generating functions
2 parents 5929f2f + 7c27a12 commit 94341af

File tree

4 files changed

+76
-27
lines changed

4 files changed

+76
-27
lines changed

README.md

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,13 @@ Query the document strings for details, or see
4141

4242
## Testing models in a new MLJ model interface implementation
4343

44-
The following tests the model interface implemented by some model type
45-
`MyClassifier`, as might appear in tests for a package providing that
46-
type:
44+
The following tests the model interface implemented by some model type `MyClassifier` for
45+
multiclass classification, as might appear in tests for a package providing that type:
4746

4847
```julia
4948
import MLJTestIntegration
5049
using Test
51-
X, y = MLJTestIntegration.MLJ.make_blobs()
50+
X, y = MLJTestIntegration.make_multiclass()
5251
failures, summary = MLJTestIntegration.test([MyClassifier, ], X, y, verbosity=1, mod=@__MODULE__)
5352
@test isempty(failures)
5453
```
@@ -78,3 +77,17 @@ failures, summary =
7877

7978
summary |> DataFrame
8079
```
80+
81+
# Datasets
82+
83+
The following commands generate datasets of the form `(X, y)` suitable for integration
84+
tests:
85+
86+
- `MLJTestIntegration.make_binary`
87+
88+
- `MLJTestIntegration.make_multiclass`
89+
90+
- `MLJTestIntegration.make_regression`
91+
92+
- `MLJTestIntegration.make_count`
93+

src/special_cases.jl

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,60 @@ end
3939
_test(data; ignore=true, kwargs...) = _test([], data; ignore, kwargs...)
4040

4141

42-
# # SINGLE TARGET CLASSIFICATION
42+
# # BABY DATA SETS
43+
44+
"""
45+
make_binary()
4346
44-
function _make_binary()
47+
Return data `(X, y)` for the crabs dataset, restricted to the two features `:FL`,
48+
`:RW`. Target is `Multiclass{2}`.
49+
50+
"""
51+
function make_binary()
4552
data = MLJ.load_crabs()
4653
y_, X = unpack(data, ==(:sp), col->col in [:FL, :RW])
4754
y = coerce(y_, MLJ.OrderedFactor)
4855
return X, y
4956
end
5057

58+
"""
59+
make_multiclass()
60+
61+
Return data `(X, y)` for the unshuffled iris dataset. Target is `Multiclass{3}`.
62+
63+
"""
64+
make_multiclass() = MLJ.@load_iris
65+
66+
"""
67+
make_regression()
68+
69+
Return data `(X, y)` for the Boston dataset, restricted to the two features `:LStat`,
70+
`:Rm`. Target is `Continuous`.
71+
72+
"""
73+
function make_regression()
74+
data = MLJ.load_boston()
75+
y, X = unpack(data, ==(:MedV), col->col in [:LStat, :Rm])
76+
return X, y
77+
end
78+
79+
"""
80+
make_regression()
81+
82+
Return data `(X, y)` for the Boston dataset, restricted to the two features `:LStat`,
83+
`:Rm`, with the `Continuous` target converted to `Count` (integer).
84+
85+
"""
86+
function make_count()
87+
X, y_ = make_regression()
88+
y = map-> round(Int, η), y_)
89+
return X, y
90+
end
91+
92+
93+
# # SINGLE TARGET CLASSIFICATION
94+
95+
5196
"""
5297
MLJTestIntegration.test_single_target_classifiers(; keyword_options...)
5398
@@ -62,17 +107,11 @@ $DOC_AS_ABOVE
62107
63108
"""
64109
test_single_target_classifiers(args...; kwargs...) =
65-
_test(args..., _make_binary(); kwargs...)
110+
_test(args..., make_binary(); kwargs...)
66111

67112

68113
# # SINGLE TARGET REGRESSION
69114

70-
function _make_baby_boston()
71-
data = MLJ.load_boston()
72-
y, X = unpack(data, ==(:MedV), col->col in [:LStat, :Rm])
73-
return X, y
74-
end
75-
76115
"""
77116
MLJTestIntegration.test_single_target_regressors(; keyword_options...)
78117
@@ -87,17 +126,11 @@ $DOC_AS_ABOVE
87126
88127
"""
89128
test_single_target_regressors(args...; kwargs...) =
90-
_test(args..., _make_baby_boston(); kwargs...)
129+
_test(args..., make_regression(); kwargs...)
91130

92131

93132
# # SINGLE TARGET COUNT REGRESSORS
94133

95-
function _make_count()
96-
X, y_ = _make_baby_boston()
97-
y = map-> round(Int, η), y_)
98-
return X, y
99-
end
100-
101134
"""
102135
MLJTestIntegration.test_single_count_regressors(; keyword_options...)
103136
@@ -114,12 +147,12 @@ $DOC_AS_ABOVE
114147
115148
"""
116149
test_single_target_count_regressors(args...; kwargs...) =
117-
_test(args..., _make_count(); kwargs...)
150+
_test(args..., make_count(); kwargs...)
118151

119152

120153
# # CONTINUOUS TABLE TRANSFORMERS
121154

122-
_make_transformer() = (first(_make_baby_boston()),)
155+
_make_transformer() = (first(make_regression()),)
123156

124157
"""
125158
test_continuous_table_transformers(; keyword_options...)

src/test.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ function next!(p)
99
MLJ.ProgressMeter.updateProgress!(p)
1010
end
1111

12+
const ENSEMBLE_TARGET_ELSCITYPE = Union{Missing, Continuous, Finite}
13+
1214
"""
1315
test(models, data...; mod=Main, level=2, throw=false, verbosity=1)
1416
@@ -382,15 +384,16 @@ function test(model_proxies, data...; mod=Main, level=2, throw=false, verbosity=
382384
outcome == "×" && continue
383385

384386
#[ensemble_prediction]:
385-
ensemble_prediction, outcome =
386-
MLJTestIntegration.ensemble_prediction(
387+
if target_scitype(model_type) <: AbstractVector{<:ENSEMBLE_TARGET_ELSCITYPE}
388+
ensemble_prediction, outcome = MLJTestIntegration.ensemble_prediction(
387389
model_instance,
388390
data...;
389391
throw,
390392
verbosity,
391393
)
392-
row = update(row, i, :ensemble_prediction, ensemble_prediction, outcome)
393-
outcome == "×" && continue
394+
row = update(row, i, :ensemble_prediction, ensemble_prediction, outcome)
395+
outcome == "×" && continue
396+
end
394397

395398
# [iteration_prediction]:
396399
if !isnothing(iteration_parameter(model_instance))

test/special_cases.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ regressors = [
99
]
1010

1111
@testset "actual_proxies" begin
12-
data = MTI._make_baby_boston()
12+
data = MTI.make_regression()
1313
proxies = @test_logs MTI.actual_proxies(regressors, data, false, 1)
1414
@test proxies == regressors
1515
proxies2 = @test_logs MTI.actual_proxies(regressors, data, true, 1)

0 commit comments

Comments
 (0)