Skip to content

Commit 195dc94

Browse files
authored
Merge pull request #27 from JuliaAI/refactor-using-mljtestinterface
Refactor using MLJTestInterface.jl as foundation
2 parents 3afc7bb + ca4c9a4 commit 195dc94

File tree

6 files changed

+23
-162
lines changed

6 files changed

+23
-162
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@ version = "0.2.6"
55

66
[deps]
77
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
8+
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
89
NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"
910
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1011
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1112

1213
[compat]
13-
MLJ = "0.18, 0.19"
14+
MLJ = "0.18, 0.19, 0.20"
1415
NearestNeighborModels = "0.2"
1516
julia = "1.6"

README.md

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ Package for applying integration tests to models implementing the
44
[MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) model
55
interface.
66

7+
**To test implementations of the MLJ model interface, use [MLJTestInterface.jl]()
8+
instead.**
9+
710
[![Lifecycle:Experimental](https://img.shields.io/badge/Lifecycle-Experimental-339999)](https://github.com/bcgov/repomountie/blob/master/doc/lifecycle-badges.md) [![Build Status](https://github.com/JuliaAI/MLJTestIntegration.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/MLJTestIntegration.jl/actions) [![Coverage](https://codecov.io/gh/JuliaAI/MLJTestIntegration.jl/branch/master/graph/badge.svg)](https://codecov.io/github/JuliaAI/MLJTestIntegration.jl?branch=master)
811

912
# Installation
@@ -37,22 +40,7 @@ Query the document strings for details, or see
3740
[examples/bigtest/notebook.jl](examples/bigtest/notebook.jl).
3841

3942

40-
# Examples
41-
42-
## Testing models in a new MLJ model interface implementation
43-
44-
The following tests the model interface implemented by some model type `MyClassifier` for
45-
multiclass classification, as might appear in tests for a package providing that type:
46-
47-
```julia
48-
import MLJTestIntegration
49-
using Test
50-
X, y = MLJTestIntegration.make_multiclass()
51-
failures, summary = MLJTestIntegration.test([MyClassifier, ], X, y, verbosity=1, mod=@__MODULE__)
52-
@test isempty(failures)
53-
```
54-
55-
## Testing models after filtering models in the registry
43+
# Example: Testing models filtered from the MLJ model registry
5644

5745
The following applies comprehensive integration tests to all
5846
regressors provided by the package GLM.jl appearing in the MLJ Model

src/MLJTestIntegration.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ using Pkg
77
using .Threads
88
using Test
99
using NearestNeighborModels
10+
import MLJTestInterface
11+
const MTI = MLJTestInterface
12+
import MLJTestInterface.attempt
1013

1114
include("attemptors.jl")
1215
include("test.jl")

src/attemptors.jl

Lines changed: 7 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,128 +1,15 @@
11
const ERR_INCONSISTENT_RESULTS =
22
"Different computational resources are giving different results. "
33

4-
"""
5-
attempt(f, message; throw=false)
64

7-
Return `(f(), "✓") if `f()` executes without throwing an
8-
exception. Otherwise, return `(ex, "×"), where `ex` is the exception
9-
caught. Only truly throw the exception if `throw=true`.
10-
11-
If `message` is not empty, then it is logged to `Info`, together with
12-
the second return value ("✓" or "×").
13-
14-
15-
"""
16-
function attempt(f, message; throw=false)
17-
ret = try
18-
(f(), "")
19-
catch ex
20-
throw && Base.throw(ex)
21-
(ex, "×")
22-
end
23-
isempty(message) || @info message*last(ret)
24-
return ret
25-
end
26-
27-
finalize(message, verbosity) = verbosity < 2 ? "" : message
28-
29-
30-
# # ATTEMPTORS
31-
32-
# TODO: Instead, in ****** below, use `MLJ.load_path`, after MLJModels
33-
# is updated to 0.16. And delete the two methods immediately
34-
# following. What's required will already be in MLJModels 0.15.10, but
35-
# the current implementation avoids an explicit MLJModels dependency
36-
# for MLJTestIntegration.
37-
load_path(model_type) = MLJ.load_path(model_type)
38-
function load_path(proxy::NamedTuple)
39-
handle = (name=proxy.name, pkg=proxy.package_name)
40-
return MLJ.MLJModels.INFO_GIVEN_HANDLE[handle][:load_path]
41-
end
5+
# # NEW ATTEMPTORS
426

437
root(load_path) = split(load_path, '.') |> first
448

45-
function model_type(proxy, mod; throw=false, verbosity=1)
46-
# check interface package really is in current environment:
47-
message = "[:model_type] Loading model type "
48-
model_type, outcome = attempt(finalize(message, verbosity); throw) do
49-
load_path = MLJTestIntegration.load_path(proxy) # MLJ.load_path(proxy) *****
50-
load_path_ex = load_path |> Meta.parse
51-
api_pkg_ex = root(load_path) |> Symbol
52-
import_ex = :(import $api_pkg_ex)
53-
quote
54-
$import_ex
55-
$load_path_ex
56-
end |> mod.eval
57-
end
58-
59-
# catch case of interface package not in current environment:
60-
if outcome == "×" && model_type isa ArgumentError
61-
# try to get the name of interface package; if this fails we
62-
# catch the exception thrown but take no further
63-
# action. Otherwise, we test if the original exception caught
64-
# above, `model_type`, was triggered because of API package is
65-
# missing from in environment.
66-
api_pkg = try
67-
load_path = MLJTestIntegration.load_path(proxy) # MLJ.load_path(proxy) *****
68-
api_pkg = root(load_path)
69-
catch
70-
nothing
71-
end
72-
if !isnothing(api_pkg) &&
73-
api_pkg != "unknown" &&
74-
contains(model_type.msg, "$api_pkg not found in")
75-
Base.throw(model_type)
76-
end
77-
end
78-
79-
return model_type, outcome
80-
end
81-
82-
function model_instance(model_type; throw=false, verbosity=1)
83-
message = "[:model_instance] Instantiating default model "
84-
attempt(finalize(message, verbosity); throw) do
85-
model_type()
86-
end
87-
end
88-
89-
function fitted_machine(model, data...; throw=false, verbosity=1)
90-
message = "[:fitted_machine] Fitting machine "
91-
attempt(finalize(message, verbosity); throw) do
92-
mach = model isa Static ? machine(model) :
93-
machine(model, data...)
94-
fit!(mach, verbosity=-1)
95-
MLJ.report(mach)
96-
MLJ.fitted_params(mach)
97-
mach
98-
end
99-
end
100-
101-
function operations(fitted_machine, data...; throw=false, verbosity=1)
102-
message = "[:operations] Calling `predict`, `transform` and/or `inverse_transform` "
103-
attempt(finalize(message, verbosity); throw) do
104-
operations = String[]
105-
methods = MLJ.implemented_methods(fitted_machine.model)
106-
if :predict in methods
107-
predict(fitted_machine, first(data))
108-
push!(operations, "predict")
109-
end
110-
if :transform in methods
111-
W = transform(fitted_machine, first(data))
112-
push!(operations, "transform")
113-
if :inverse_transform in methods
114-
inverse_transform(fitted_machine, W)
115-
push!(operations, "inverse_transform")
116-
end
117-
end
118-
join(operations, ", ")
119-
end
120-
end
121-
1229
function threshold_prediction(model, data...; throw=false, verbosity=1)
12310
message = "[:threshold_predictor] Calling fit!/predict for threshold predictor "*
12411
"test) "
125-
attempt(finalize(message, verbosity); throw) do
12+
attempt(MTI.finalize(message, verbosity); throw) do
12613
tmodel = BinaryThresholdPredictor(model)
12714
mach = machine(tmodel, data...)
12815
fit!(mach, verbosity=0)
@@ -134,7 +21,7 @@ function evaluation(measure, model, resources, data...; throw=false, verbosity=1
13421
L = length(resources)
13522
message = L > 1 ? "[:accelerated_evaluation] " : "[evaluation] "
13623
message *= "Evaluating model performance using $L different resources. "
137-
attempt(finalize(message, verbosity); throw) do
24+
attempt(MTI.finalize(message, verbosity); throw) do
13825
es = map(resources) do resource
13926
evaluate(model, data...;
14027
measure=measure,
@@ -157,7 +44,7 @@ function tuned_pipe_evaluation(
15744
verbosity=1,
15845
)
15946
message = "[:tuned_pipe_evaluation] Evaluating perfomance in a tuned pipeline "
160-
attempt(finalize(message, verbosity); throw) do
47+
attempt(MTI.finalize(message, verbosity); throw) do
16148
pipe = identity |> model
16249
tuned_pipe = TunedModel(
16350
models=fill(pipe, 3),
@@ -172,7 +59,7 @@ function tuned_pipe_evaluation(
17259
end
17360

17461
function ensemble_prediction(model, data...; throw=false, verbosity=1)
175-
attempt(finalize("[:ensemble_prediction] Ensembling ", verbosity); throw) do
62+
attempt(MTI.finalize("[:ensemble_prediction] Ensembling ", verbosity); throw) do
17663
imodel = EnsembleModel(
17764
model=model,
17865
n=2,
@@ -186,7 +73,7 @@ end
18673
# the `model` must support iteration (`!isnothing(iteration_paramater(model))`)
18774
function iteration_prediction(measure, model, data...; throw=false, verbosity=1)
18875
message = "[:iteration_prediction] Iterating with controls "
189-
attempt(finalize(message, verbosity); throw) do
76+
attempt(MTI.finalize(message, verbosity); throw) do
19077
imodel = IteratedModel(model=model,
19178
measure=measure,
19279
controls=[Step(1),
@@ -240,7 +127,7 @@ function stack_evaluation(
240127
isregressor = AbstractVector{Continuous} <: target_scitype
241128
measure = isregressor ? LPLoss(2) : BrierScore()
242129

243-
attempt(finalize(message, verbosity); throw) do
130+
attempt(MTI.finalize(message, verbosity); throw) do
244131
es = map(resources) do resource
245132
stack = _stack(model, resource, isregressor)
246133
evaluate(

src/test.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,27 +253,27 @@ function test(model_proxies, data...; mod=Main, level=2, throw=false, verbosity=
253253
row = merge(row0, (; name, package_name))
254254

255255
# [model_type]:
256-
model_type, outcome = MLJTestIntegration.model_type(model_proxy, mod; throw, verbosity)
256+
model_type, outcome = MTI.model_type(model_proxy, mod; throw, verbosity)
257257
row = update(row, i, :model_type, model_type, outcome)
258258
outcome == "×" && continue
259259

260260
level > 1 || continue
261261

262262
# [model_instance]:
263263
model_instance, outcome =
264-
MLJTestIntegration.model_instance(model_type; throw, verbosity)
264+
MTI.model_instance(model_type; throw, verbosity)
265265
row = update(row, i, :model_instance, model_instance, outcome)
266266
outcome == "×" && continue
267267

268268
# [fitted_machine]:
269269
fitted_machine, outcome =
270-
MLJTestIntegration.fitted_machine(model_instance, data...; throw, verbosity)
270+
MTI.fitted_machine(model_instance, data...; throw, verbosity)
271271
row = update(row, i, :fitted_machine, fitted_machine, outcome)
272272
outcome == "×" && continue
273273

274274
# [operations]:
275275
operations, outcome =
276-
MLJTestIntegration.operations(fitted_machine, data...; throw, verbosity)
276+
MTI.operations(fitted_machine, data...; throw, verbosity)
277277
# special treatment to get list of operations in `summary`:
278278
if outcome == "×"
279279
row = update(row, i, :operations, operations, outcome)

test/attemptors.jl

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,12 @@
1-
@testset "attempt()" begin
2-
e = ArgumentError("elephant")
3-
bad() = throw(e)
4-
good() = 42
5-
6-
@test (@test_logs MLJTestIntegration.attempt(bad, "")) == (e, "×")
7-
@test(@test_logs(
8-
(:info, "look ×"),
9-
MLJTestIntegration.attempt(bad, "look "),
10-
) == (e, "×"))
11-
@test (@test_logs MLJTestIntegration.attempt(good, "")) == (42, "")
12-
@test (@test_logs(
13-
(:info, "look ✓"),
14-
MLJTestIntegration.attempt(good, "look "),
15-
) == (42, ""))
16-
@test_throws e MLJTestIntegration.attempt(bad, ""; throw=true)
17-
end
18-
19-
@testset "model_type" begin
1+
@testset "model_type with model proxy instead of type" begin
202

213
# test error thrown (not caught) if pkg missing from environment:
22-
@test_throws ArgumentError MLJTestIntegration.model_type(
4+
@test_throws ArgumentError MLJTestIntegration.MLJTestInterface.model_type(
235
(name="PCA", package_name="MultivariateStats"),
246
@__MODULE__
257
)
268

27-
M, outcome = MLJTestIntegration.model_type(
9+
M, outcome = MLJTestIntegration.MLJTestInterface.model_type(
2810
(name="DecisionTreeClassifier", package_name="DecisionTree"),
2911
@__MODULE__;
3012
verbosity=0

0 commit comments

Comments
 (0)