Skip to content

Commit dc2b045

Browse files
committed
add level=4 accelerated_evaluation test
1 parent f42bd2a commit dc2b045

File tree

6 files changed

+164
-28
lines changed

6 files changed

+164
-28
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ version = "0.1.0"
55

66
[deps]
77
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
8+
MLJTuning = "03970b2e-30c4-11ea-3135-d1576263f10f"
89
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
10+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
911

1012
[compat]
1113
MLJ = "0.18"

src/MLJTestIntegration.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,22 @@
11
module MLJTestIntegration
22

3+
const N_MODELS_FOR_REPEATABILITY_TEST = 3
4+
35
using MLJ
46
using Pkg
7+
using .Threads
8+
using Test
59

610
include("attemptors.jl")
711
include("test.jl")
812
include("special_cases.jl")
913
include("dummy_model.jl")
1014

15+
function __init__()
16+
global RESOURCES = (CPU1(), CPUThreads())
17+
@info "Testing with $(nthreads()) threads. "
18+
end
19+
1120
using .DummyModel
1221

1322
end # module

src/attemptors.jl

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
const ERR_INCONSISTENT_RESULTS =
2+
"Different computational resources are giving different results. "
3+
14
"""
25
attempt(f, message; throw=false)
36
47
Return `(f(), "✓")` if `f()` executes without throwing an
58
exception. Otherwise, return `(ex, "×")`, where `ex` is the exception
6-
caught. Only truly throw the exception if `throw=true`.
9+
caught. Only truly throw the exception if `throw=true`.
710
811
If `message` is not empty, then it is logged to `Info`, together with
912
the second return value ("✓" or "×").
@@ -123,32 +126,51 @@ function threshold_prediction(model, data...; throw=false, verbosity=1)
123126
end
124127
end
125128

126-
function evaluation(measure, model, data...; throw=false, verbosity=1)
129+
function evaluation(measure, model, resources, data...; throw=false, verbosity=1)
127130
message = "[:evaluation] Evaluating performance "
128131
attempt(finalize(message, verbosity); throw) do
129-
evaluate(model, data...;
130-
measure=measure,
131-
resampling=Holdout(),
132-
verbosity=0)
132+
es = map(resources) do accel
133+
evaluate(model, data...;
134+
measure=measure,
135+
resampling=Holdout(),
136+
acceleration=accel,
137+
verbosity=0)
138+
end
139+
ms = map(e->e.measurement, es)
140+
m = first(ms)
141+
@assert all((m), collect(ms)[2:end]) ERR_INCONSISTENT_RESULTS
142+
return first(es)
133143
end
134144
end
135145

136-
function tuned_pipe_evaluation(measure, model, data...; throw=false, verbosity=1)
146+
function tuned_pipe_evaluation(
147+
measure,
148+
model,
149+
data...;
150+
throw=false,
151+
verbosity=1,
152+
)
137153
message = "[:tuned_pipe_evaluation] Evaluating perfomance in a tuned pipeline "
138154
attempt(finalize(message, verbosity); throw) do
139155
pipe = identity |> model
140-
tuned_pipe = TunedModel(models=[pipe,],
141-
measure=measure)
142-
evaluate(tuned_pipe, data...;
143-
measure=measure,
144-
verbosity=0);
156+
tuned_pipe = TunedModel(
157+
models=[pipe,],
158+
measure=measure,
159+
)
160+
evaluate(
161+
tuned_pipe, data...;
162+
measure=measure,
163+
verbosity=0,
164+
)
145165
end
146166
end
147167

148168
function ensemble_prediction(model, data...; throw=false, verbosity=1)
149169
attempt(finalize("[:ensemble_prediction] Ensembling ", verbosity); throw) do
150-
imodel = EnsembleModel(model=model,
151-
n=2)
170+
imodel = EnsembleModel(
171+
model=model,
172+
n=2,
173+
)
152174
mach = machine(imodel, data...)
153175
fit!(mach, verbosity=0)
154176
predict(mach, first(data))

src/test.jl

Lines changed: 73 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ with automatic code loading" below.
2727
2828
The extent of testing is controlled by `level`:
2929
30-
|`level` | description | tests (full list below) |
31-
|:----------------|:---------------------------------|:------------------------|
32-
| 1 | test code loading | `:model_type` |
33-
| 2 (default) | basic test of model interface | first four tests |
34-
| 3 | comprehensive | all applicable tests |
30+
|`level` | description | tests (full list below) |
31+
|:----------------|:----------------------------------|:------------------------|
32+
| 1 | test code loading | `:model_type` |
33+
| 2 (default) | basic test of model interface | first four tests |
34+
| 3 | comprehensive CPU1() | all CPU1() tests |
35+
| 4 | comprehensive CPU1()/CPUThreads() | all tests |
3536
3637
By default, exceptions caught in tests are not thrown. If
3738
`throw=true`, testing will terminate at the first exception
@@ -131,6 +132,10 @@ These additional tests are applied to `Supervised` models:
131132
(metric), evaluate the performance of the model using `evaluate!`
132133
and a `Holdout` set.
133134
135+
- `:accelerated_evaluation`: Assuming the model appears to make
136+
repeatable predictions on retraining, repeat the `:evaluation` test
137+
using `CPUThreads()` acceleration and check agreement with the `CPU1()` case.
138+
134139
- `:tuned_pipe_evaluation`: Repeat the `:evaluation` test but first
135140
insert model in a pipeline with a trivial pre-processing step
136141
(applies the identity transformation) and wrap in `TunedModel` (only
@@ -156,11 +161,12 @@ function test(model_proxies, data...; mod=Main, level=2, throw=false, verbosity=
156161
:fitted_machine,
157162
:operations,
158163
:evaluation,
164+
:accelerated_evaluation,
159165
:tuned_pipe_evaluation,
160166
:threshold_prediction,
161167
:ensemble_prediction,
162168
:iteration_prediction
163-
), NTuple{11, String}}}(undef, nproxies)
169+
), NTuple{12, String}}}(undef, nproxies)
164170

165171
# summary table row corresponding to all tests skipped:
166172
row0 = (
@@ -171,6 +177,7 @@ function test(model_proxies, data...; mod=Main, level=2, throw=false, verbosity=
171177
fitted_machine = "-",
172178
operations = "-",
173179
evaluation = "-",
180+
accelerated_evaluation = "-",
174181
tuned_pipe_evaluation = "-",
175182
threshold_prediction = "-",
176183
ensemble_prediction = "-",
@@ -269,10 +276,56 @@ function test(model_proxies, data...; mod=Main, level=2, throw=false, verbosity=
269276

270277
# evaluation:
271278
evaluation, outcome =
272-
MLJTestIntegration.evaluation(measure, model_instance, data...; throw, verbosity)
279+
MLJTestIntegration.evaluation(
280+
measure,
281+
model_instance,
282+
[CPU1(),],
283+
data...;
284+
throw,
285+
verbosity,
286+
)
273287
row = update(row, i, :evaluation, evaluation, outcome)
274288
outcome == "×" && continue
275289

290+
# determine computational resources to test; we only test more
291+
# than CPU1() if model evaluations are independent of training
292+
# run (assuming this means models are "deterministic", ie,
293+
# RNGs):
294+
resources = MLJ.AbstractResource[] # fallback
295+
if level > 3
296+
per_fold = evaluation.per_fold[1]
297+
per_folds = map(1:(N_MODELS_FOR_REPEATABILITY_TEST - 1)) do _
298+
e, o = MLJTestIntegration.evaluation(
299+
measure,
300+
model_instance,
301+
[CPU1(),],
302+
data...;
303+
throw=false,
304+
verbosity,
305+
)
306+
o == "" || return nothing
307+
e.per_fold[1]
308+
end
309+
if all((per_fold), per_folds)
310+
resources = RESOURCES
311+
end
312+
end
313+
314+
if length(resources) > 1
315+
# accelerated_evaluation:
316+
evaluation, outcome =
317+
MLJTestIntegration.evaluation(
318+
measure,
319+
model_instance,
320+
resources,
321+
data...;
322+
throw,
323+
verbosity,
324+
)
325+
row = update(row, i, :accelerated_evaluation, evaluation, outcome)
326+
outcome == "×" && continue
327+
end
328+
276329
# tuned_pipe_evaluation:
277330
tuned_pipe_evaluation, outcome =
278331
MLJTestIntegration.tuned_pipe_evaluation(
@@ -287,15 +340,26 @@ function test(model_proxies, data...; mod=Main, level=2, throw=false, verbosity=
287340

288341
# ensemble_prediction:
289342
ensemble_prediction, outcome =
290-
MLJTestIntegration.ensemble_prediction(model_instance, data...; throw, verbosity)
343+
MLJTestIntegration.ensemble_prediction(
344+
model_instance,
345+
data...;
346+
throw,
347+
verbosity,
348+
)
291349
row = update(row, i, :ensemble_prediction, ensemble_prediction, outcome)
292350
outcome == "×" && continue
293351

294352
isnothing(iteration_parameter(model_instance)) && continue
295353

296354
# iteration prediction:
297355
iteration_prediction, outcome =
298-
MLJTestIntegration.iteration_prediction(measure, model_instance, data...; throw, verbosity)
356+
MLJTestIntegration.iteration_prediction(
357+
measure,
358+
model_instance,
359+
data...;
360+
throw,
361+
verbosity,
362+
)
299363
row = update(row, i, :iteration_prediction, iteration_prediction, outcome)
300364
outcome == "×" && continue
301365
end

test/attemptors.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@
44
good() = 42
55

66
@test (@test_logs MLJTestIntegration.attempt(bad, "")) == (e, "×")
7-
@test (@test_logs (:info, "look ×") MLJTestIntegration.attempt(bad, "look ")) == (e, "×")
7+
@test(@test_logs(
8+
(:info, "look ×"),
9+
MLJTestIntegration.attempt(bad, "look "),
10+
) == (e, "×"))
811
@test (@test_logs MLJTestIntegration.attempt(good, "")) == (42, "")
9-
@test (@test_logs (:info, "look ✓") MLJTestIntegration.attempt(good, "look ")) == (42, "")
12+
@test (@test_logs(
13+
(:info, "look ✓"),
14+
MLJTestIntegration.attempt(good, "look "),
15+
) == (42, ""))
1016
@test_throws e MLJTestIntegration.attempt(bad, ""; throw=true)
1117
end
1218

test/test.jl

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ expected_summary1 = (
1111
fitted_machine = "",
1212
operations = "predict",
1313
evaluation = "",
14+
accelerated_evaluation = "",
1415
tuned_pipe_evaluation = "",
1516
threshold_prediction = "",
1617
ensemble_prediction = "",
@@ -25,6 +26,7 @@ expected_summary2 = (
2526
fitted_machine = "",
2627
operations = "predict",
2728
evaluation = "",
29+
accelerated_evaluation = "",
2830
tuned_pipe_evaluation = "",
2931
threshold_prediction = "-",
3032
ensemble_prediction = "",
@@ -41,7 +43,7 @@ expected_summary2 = (
4143
X,
4244
y;
4345
mod=@__MODULE__,
44-
level=3,
46+
level=4,
4547
verbosity=0
4648
)
4749
@test isempty(fails)
@@ -61,7 +63,7 @@ end
6163
X,
6264
y;
6365
mod=@__MODULE__,
64-
level=3,
66+
level=4,
6567
verbosity=0
6668
)
6769
@test isempty(fails)
@@ -109,6 +111,7 @@ end
109111
fitted_machine = "×",
110112
operations = "-",
111113
evaluation = "-",
114+
accelerated_evaluation = "-",
112115
tuned_pipe_evaluation = "-",
113116
threshold_prediction = "-",
114117
ensemble_prediction = "-",
@@ -123,6 +126,7 @@ end
123126
fitted_machine = "",
124127
operations = "predict",
125128
evaluation = "×",
129+
accelerated_evaluation = "-",
126130
tuned_pipe_evaluation = "-",
127131
threshold_prediction = "-",
128132
ensemble_prediction = "-",
@@ -201,6 +205,7 @@ end
201205
fitted_machine = "-",
202206
operations = "-",
203207
evaluation = "-",
208+
accelerated_evaluation = "-",
204209
tuned_pipe_evaluation = "-",
205210
threshold_prediction = "-",
206211
ensemble_prediction = "-",
@@ -225,11 +230,37 @@ end
225230
fitted_machine = "",
226231
operations = "predict",
227232
evaluation = "-",
233+
accelerated_evaluation = "-",
228234
tuned_pipe_evaluation = "-",
229235
threshold_prediction = "-",
230236
ensemble_prediction = "-",
231237
iteration_prediction = "-",
232238
)
239+
240+
# level=4:
241+
fails, summary =
242+
@test_logs MLJTestIntegration.test(
243+
classifiers,
244+
X,
245+
y;
246+
mod=@__MODULE__,
247+
level=4,
248+
verbosity=0)
249+
@test isempty(fails)
250+
@test summary[1] == (
251+
name = "ConstantClassifier",
252+
package_name = "MLJModels",
253+
model_type = "",
254+
model_instance = "",
255+
fitted_machine = "",
256+
operations = "predict",
257+
evaluation = "",
258+
accelerated_evaluation = "",
259+
tuned_pipe_evaluation = "",
260+
threshold_prediction = "",
261+
ensemble_prediction = "",
262+
iteration_prediction = "-",
263+
)
233264
end
234265

235266
@testset "iterative model" begin
@@ -252,8 +283,10 @@ end
252283
fitted_machine = "",
253284
operations = "predict",
254285
evaluation = "",
286+
accelerated_evaluation = "-",
255287
tuned_pipe_evaluation = "",
256288
threshold_prediction = "-",
257289
ensemble_prediction = "",
258-
iteration_prediction = "",)
290+
iteration_prediction = "",
291+
)
259292
end

0 commit comments

Comments
 (0)