Skip to content

Commit 7c21e5b

Browse files
authored
Merge pull request #28 from JuliaAI/dev
For a 0.3.0 release
2 parents f1e2bf1 + c5fd43d commit 7c21e5b

File tree

7 files changed

+28
-165
lines changed

7 files changed

+28
-165
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
name = "MLJTestIntegration"
22
uuid = "697918b4-fdc1-4f9e-8ff9-929724cee270"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.2.5"
4+
version = "0.3.0"
55

66
[deps]
77
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
8+
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
89
NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"
910
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1011
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1112

1213
[compat]
13-
MLJ = "0.18"
14+
MLJ = "0.18, 0.19, 0.20"
1415
NearestNeighborModels = "0.2"
1516
julia = "1.6"

README.md

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ Package for applying integration tests to models implementing the
44
[MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) model
55
interface.
66

7+
**To test implementations of the MLJ model interface, use [MLJTestInterface.jl](https://github.com/JuliaAI/MLJTestInterface.jl)
8+
instead.**
9+
710
[![Lifecycle:Experimental](https://img.shields.io/badge/Lifecycle-Experimental-339999)](https://github.com/bcgov/repomountie/blob/master/doc/lifecycle-badges.md) [![Build Status](https://github.com/JuliaAI/MLJTestIntegration.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/MLJTestIntegration.jl/actions) [![Coverage](https://codecov.io/gh/JuliaAI/MLJTestIntegration.jl/branch/master/graph/badge.svg)](https://codecov.io/github/JuliaAI/MLJTestIntegration.jl?branch=master)
811

912
# Installation
@@ -37,22 +40,7 @@ Query the document strings for details, or see
3740
[examples/bigtest/notebook.jl](examples/bigtest/notebook.jl).
3841

3942

40-
# Examples
41-
42-
## Testing models in a new MLJ model interface implementation
43-
44-
The following tests the model interface implemented by some model type `MyClassifier` for
45-
multiclass classification, as might appear in tests for a package providing that type:
46-
47-
```julia
48-
import MLJTestIntegration
49-
using Test
50-
X, y = MLJTestIntegration.make_multiclass()
51-
failures, summary = MLJTestIntegration.test([MyClassifier, ], X, y, verbosity=1, mod=@__MODULE__)
52-
@test isempty(failures)
53-
```
54-
55-
## Testing models after filtering models in the registry
43+
# Example: Testing models filtered from the MLJ model registry
5644

5745
The following applies comprehensive integration tests to all
5846
regressors provided by the package GLM.jl appearing in the MLJ Model
@@ -69,7 +57,9 @@ regressors = MLJTestIntegration.MLJ.models(matching(X, y)) do m
6957
end
7058

7159
# to test code loading:
72-
MLJTestIntegration.test(regressors, X, y, verbosity=2, mod=@__MODULE__, level=1)
60+
failures, summary =
61+
MLJTestIntegration.test(regressors, X, y, verbosity=2, mod=@__MODULE__, level=1)
62+
@assert isempty(failures)
7363

7464
# comprehensive tests:
7565
failures, summary =

examples/bigtest/Manifest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,7 @@ version = "0.7.1"
956956

957957
[[deps.MLJModelInterface]]
958958
deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
959-
path = "/Users/anthony/MLJ/MLJModelInterface"
959+
git-tree-sha1 = "4040c0da2bd05130687cc258c1318acd32bace90"
960960
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
961961
version = "1.7.1"
962962

src/MLJTestIntegration.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ using Pkg
77
using .Threads
88
using Test
99
using NearestNeighborModels
10+
import MLJTestInterface
11+
const MTI = MLJTestInterface
12+
import MLJTestInterface.attempt
1013

1114
include("attemptors.jl")
1215
include("test.jl")

src/attemptors.jl

Lines changed: 7 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,128 +1,15 @@
11
const ERR_INCONSISTENT_RESULTS =
22
"Different computational resources are giving different results. "
33

4-
"""
5-
attempt(f, message; throw=false)
64

7-
Return `(f(), "✓") if `f()` executes without throwing an
8-
exception. Otherwise, return `(ex, "×"), where `ex` is the exception
9-
caught. Only truly throw the exception if `throw=true`.
10-
11-
If `message` is not empty, then it is logged to `Info`, together with
12-
the second return value ("✓" or "×").
13-
14-
15-
"""
16-
function attempt(f, message; throw=false)
17-
ret = try
18-
(f(), "")
19-
catch ex
20-
throw && Base.throw(ex)
21-
(ex, "×")
22-
end
23-
isempty(message) || @info message*last(ret)
24-
return ret
25-
end
26-
27-
finalize(message, verbosity) = verbosity < 2 ? "" : message
28-
29-
30-
# # ATTEMPTORS
31-
32-
# TODO: Instead, in ****** below, use `MLJ.load_path`, after MLJModels
33-
# is updated to 0.16. And delete the two methods immediately
34-
# following. What's required will already be in MLJModels 0.15.10, but
35-
# the current implementation avoids an explicit MLJModels dependency
36-
# for MLJTestIntegration.
37-
load_path(model_type) = MLJ.load_path(model_type)
38-
function load_path(proxy::NamedTuple)
39-
handle = (name=proxy.name, pkg=proxy.package_name)
40-
return MLJ.MLJModels.INFO_GIVEN_HANDLE[handle][:load_path]
41-
end
5+
# # NEW ATTEMPTORS
426

437
root(load_path) = split(load_path, '.') |> first
448

45-
function model_type(proxy, mod; throw=false, verbosity=1)
46-
# check interface package really is in current environment:
47-
message = "[:model_type] Loading model type "
48-
model_type, outcome = attempt(finalize(message, verbosity); throw) do
49-
load_path = MLJTestIntegration.load_path(proxy) # MLJ.load_path(proxy) *****
50-
load_path_ex = load_path |> Meta.parse
51-
api_pkg_ex = root(load_path) |> Symbol
52-
import_ex = :(import $api_pkg_ex)
53-
quote
54-
$import_ex
55-
$load_path_ex
56-
end |> mod.eval
57-
end
58-
59-
# catch case of interface package not in current environment:
60-
if outcome == "×" && model_type isa ArgumentError
61-
# try to get the name of interface package; if this fails we
62-
# catch the exception thrown but take no further
63-
# action. Otherwise, we test if the original exception caught
64-
# above, `model_type`, was triggered because of API package is
65-
# missing from in environment.
66-
api_pkg = try
67-
load_path = MLJTestIntegration.load_path(proxy) # MLJ.load_path(proxy) *****
68-
api_pkg = root(load_path)
69-
catch
70-
nothing
71-
end
72-
if !isnothing(api_pkg) &&
73-
api_pkg != "unknown" &&
74-
contains(model_type.msg, "$api_pkg not found in")
75-
Base.throw(model_type)
76-
end
77-
end
78-
79-
return model_type, outcome
80-
end
81-
82-
function model_instance(model_type; throw=false, verbosity=1)
83-
message = "[:model_instance] Instantiating default model "
84-
attempt(finalize(message, verbosity); throw) do
85-
model_type()
86-
end
87-
end
88-
89-
function fitted_machine(model, data...; throw=false, verbosity=1)
90-
message = "[:fitted_machine] Fitting machine "
91-
attempt(finalize(message, verbosity); throw) do
92-
mach = model isa Static ? machine(model) :
93-
machine(model, data...)
94-
fit!(mach, verbosity=-1)
95-
MLJ.report(mach)
96-
MLJ.fitted_params(mach)
97-
mach
98-
end
99-
end
100-
101-
function operations(fitted_machine, data...; throw=false, verbosity=1)
102-
message = "[:operations] Calling `predict`, `transform` and/or `inverse_transform` "
103-
attempt(finalize(message, verbosity); throw) do
104-
operations = String[]
105-
methods = MLJ.implemented_methods(fitted_machine.model)
106-
if :predict in methods
107-
predict(fitted_machine, first(data))
108-
push!(operations, "predict")
109-
end
110-
if :transform in methods
111-
W = transform(fitted_machine, first(data))
112-
push!(operations, "transform")
113-
if :inverse_transform in methods
114-
inverse_transform(fitted_machine, W)
115-
push!(operations, "inverse_transform")
116-
end
117-
end
118-
join(operations, ", ")
119-
end
120-
end
121-
1229
function threshold_prediction(model, data...; throw=false, verbosity=1)
12310
message = "[:threshold_predictor] Calling fit!/predict for threshold predictor "*
12411
"test) "
125-
attempt(finalize(message, verbosity); throw) do
12+
attempt(MTI.finalize(message, verbosity); throw) do
12613
tmodel = BinaryThresholdPredictor(model)
12714
mach = machine(tmodel, data...)
12815
fit!(mach, verbosity=0)
@@ -134,7 +21,7 @@ function evaluation(measure, model, resources, data...; throw=false, verbosity=1
13421
L = length(resources)
13522
message = L > 1 ? "[:accelerated_evaluation] " : "[evaluation] "
13623
message *= "Evaluating model performance using $L different resources. "
137-
attempt(finalize(message, verbosity); throw) do
24+
attempt(MTI.finalize(message, verbosity); throw) do
13825
es = map(resources) do resource
13926
evaluate(model, data...;
14027
measure=measure,
@@ -157,7 +44,7 @@ function tuned_pipe_evaluation(
15744
verbosity=1,
15845
)
15946
message = "[:tuned_pipe_evaluation] Evaluating perfomance in a tuned pipeline "
160-
attempt(finalize(message, verbosity); throw) do
47+
attempt(MTI.finalize(message, verbosity); throw) do
16148
pipe = identity |> model
16249
tuned_pipe = TunedModel(
16350
models=fill(pipe, 3),
@@ -172,7 +59,7 @@ function tuned_pipe_evaluation(
17259
end
17360

17461
function ensemble_prediction(model, data...; throw=false, verbosity=1)
175-
attempt(finalize("[:ensemble_prediction] Ensembling ", verbosity); throw) do
62+
attempt(MTI.finalize("[:ensemble_prediction] Ensembling ", verbosity); throw) do
17663
imodel = EnsembleModel(
17764
model=model,
17865
n=2,
@@ -186,7 +73,7 @@ end
18673
# the `model` must support iteration (`!isnothing(iteration_paramater(model))`)
18774
function iteration_prediction(measure, model, data...; throw=false, verbosity=1)
18875
message = "[:iteration_prediction] Iterating with controls "
189-
attempt(finalize(message, verbosity); throw) do
76+
attempt(MTI.finalize(message, verbosity); throw) do
19077
imodel = IteratedModel(model=model,
19178
measure=measure,
19279
controls=[Step(1),
@@ -240,7 +127,7 @@ function stack_evaluation(
240127
isregressor = AbstractVector{Continuous} <: target_scitype
241128
measure = isregressor ? LPLoss(2) : BrierScore()
242129

243-
attempt(finalize(message, verbosity); throw) do
130+
attempt(MTI.finalize(message, verbosity); throw) do
244131
es = map(resources) do resource
245132
stack = _stack(model, resource, isregressor)
246133
evaluate(

src/test.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,27 +253,27 @@ function test(model_proxies, data...; mod=Main, level=2, throw=false, verbosity=
253253
row = merge(row0, (; name, package_name))
254254

255255
# [model_type]:
256-
model_type, outcome = MLJTestIntegration.model_type(model_proxy, mod; throw, verbosity)
256+
model_type, outcome = MTI.model_type(model_proxy, mod; throw, verbosity)
257257
row = update(row, i, :model_type, model_type, outcome)
258258
outcome == "×" && continue
259259

260260
level > 1 || continue
261261

262262
# [model_instance]:
263263
model_instance, outcome =
264-
MLJTestIntegration.model_instance(model_type; throw, verbosity)
264+
MTI.model_instance(model_type; throw, verbosity)
265265
row = update(row, i, :model_instance, model_instance, outcome)
266266
outcome == "×" && continue
267267

268268
# [fitted_machine]:
269269
fitted_machine, outcome =
270-
MLJTestIntegration.fitted_machine(model_instance, data...; throw, verbosity)
270+
MTI.fitted_machine(model_instance, data...; throw, verbosity)
271271
row = update(row, i, :fitted_machine, fitted_machine, outcome)
272272
outcome == "×" && continue
273273

274274
# [operations]:
275275
operations, outcome =
276-
MLJTestIntegration.operations(fitted_machine, data...; throw, verbosity)
276+
MTI.operations(fitted_machine, data...; throw, verbosity)
277277
# special treatment to get list of operations in `summary`:
278278
if outcome == "×"
279279
row = update(row, i, :operations, operations, outcome)

test/attemptors.jl

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,12 @@
1-
@testset "attempt()" begin
2-
e = ArgumentError("elephant")
3-
bad() = throw(e)
4-
good() = 42
5-
6-
@test (@test_logs MLJTestIntegration.attempt(bad, "")) == (e, "×")
7-
@test(@test_logs(
8-
(:info, "look ×"),
9-
MLJTestIntegration.attempt(bad, "look "),
10-
) == (e, "×"))
11-
@test (@test_logs MLJTestIntegration.attempt(good, "")) == (42, "")
12-
@test (@test_logs(
13-
(:info, "look ✓"),
14-
MLJTestIntegration.attempt(good, "look "),
15-
) == (42, ""))
16-
@test_throws e MLJTestIntegration.attempt(bad, ""; throw=true)
17-
end
18-
19-
@testset "model_type" begin
1+
@testset "model_type with model proxy instead of type" begin
202

213
# test error thrown (not caught) if pkg missing from environment:
22-
@test_throws ArgumentError MLJTestIntegration.model_type(
4+
@test_throws ArgumentError MLJTestIntegration.MLJTestInterface.model_type(
235
(name="PCA", package_name="MultivariateStats"),
246
@__MODULE__
257
)
268

27-
M, outcome = MLJTestIntegration.model_type(
9+
M, outcome = MLJTestIntegration.MLJTestInterface.model_type(
2810
(name="DecisionTreeClassifier", package_name="DecisionTree"),
2911
@__MODULE__;
3012
verbosity=0

0 commit comments

Comments
 (0)