Skip to content

Commit d5aec87

Browse files
committed
refactor, using the new MLJTestInterface.jl as foundation
1 parent 3afc7bb commit d5aec87

File tree

5 files changed

+18
-145
lines changed

5 files changed

+18
-145
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.2.6"
55

66
[deps]
77
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
8+
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
89
NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"
910
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1011
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

src/MLJTestIntegration.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ using Pkg
77
using .Threads
88
using Test
99
using NearestNeighborModels
10+
import MLJTestInterface
11+
const MTI = MLJTestInterface
12+
import MLJTestInterface.attempt
1013

1114
include("attemptors.jl")
1215
include("test.jl")

src/attemptors.jl

Lines changed: 7 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,128 +1,15 @@
11
const ERR_INCONSISTENT_RESULTS =
22
"Different computational resources are giving different results. "
33

4-
"""
5-
attempt(f, message; throw=false)
64

7-
Return `(f(), "✓") if `f()` executes without throwing an
8-
exception. Otherwise, return `(ex, "×"), where `ex` is the exception
9-
caught. Only truly throw the exception if `throw=true`.
10-
11-
If `message` is not empty, then it is logged to `Info`, together with
12-
the second return value ("✓" or "×").
13-
14-
15-
"""
16-
function attempt(f, message; throw=false)
17-
ret = try
18-
(f(), "")
19-
catch ex
20-
throw && Base.throw(ex)
21-
(ex, "×")
22-
end
23-
isempty(message) || @info message*last(ret)
24-
return ret
25-
end
26-
27-
finalize(message, verbosity) = verbosity < 2 ? "" : message
28-
29-
30-
# # ATTEMPTORS
31-
32-
# TODO: Instead, in ****** below, use `MLJ.load_path`, after MLJModels
33-
# is updated to 0.16. And delete the two methods immediately
34-
# following. What's required will already be in MLJModels 0.15.10, but
35-
# the current implementation avoids an explicit MLJModels dependency
36-
# for MLJTestIntegration.
37-
load_path(model_type) = MLJ.load_path(model_type)
38-
function load_path(proxy::NamedTuple)
39-
handle = (name=proxy.name, pkg=proxy.package_name)
40-
return MLJ.MLJModels.INFO_GIVEN_HANDLE[handle][:load_path]
41-
end
5+
# # NEW ATTEMPTORS
426

437
root(load_path) = split(load_path, '.') |> first
448

45-
function model_type(proxy, mod; throw=false, verbosity=1)
46-
# check interface package really is in current environment:
47-
message = "[:model_type] Loading model type "
48-
model_type, outcome = attempt(finalize(message, verbosity); throw) do
49-
load_path = MLJTestIntegration.load_path(proxy) # MLJ.load_path(proxy) *****
50-
load_path_ex = load_path |> Meta.parse
51-
api_pkg_ex = root(load_path) |> Symbol
52-
import_ex = :(import $api_pkg_ex)
53-
quote
54-
$import_ex
55-
$load_path_ex
56-
end |> mod.eval
57-
end
58-
59-
# catch case of interface package not in current environment:
60-
if outcome == "×" && model_type isa ArgumentError
61-
# try to get the name of interface package; if this fails we
62-
# catch the exception thrown but take no further
63-
# action. Otherwise, we test if the original exception caught
64-
# above, `model_type`, was triggered because of API package is
65-
# missing from in environment.
66-
api_pkg = try
67-
load_path = MLJTestIntegration.load_path(proxy) # MLJ.load_path(proxy) *****
68-
api_pkg = root(load_path)
69-
catch
70-
nothing
71-
end
72-
if !isnothing(api_pkg) &&
73-
api_pkg != "unknown" &&
74-
contains(model_type.msg, "$api_pkg not found in")
75-
Base.throw(model_type)
76-
end
77-
end
78-
79-
return model_type, outcome
80-
end
81-
82-
function model_instance(model_type; throw=false, verbosity=1)
83-
message = "[:model_instance] Instantiating default model "
84-
attempt(finalize(message, verbosity); throw) do
85-
model_type()
86-
end
87-
end
88-
89-
function fitted_machine(model, data...; throw=false, verbosity=1)
90-
message = "[:fitted_machine] Fitting machine "
91-
attempt(finalize(message, verbosity); throw) do
92-
mach = model isa Static ? machine(model) :
93-
machine(model, data...)
94-
fit!(mach, verbosity=-1)
95-
MLJ.report(mach)
96-
MLJ.fitted_params(mach)
97-
mach
98-
end
99-
end
100-
101-
function operations(fitted_machine, data...; throw=false, verbosity=1)
102-
message = "[:operations] Calling `predict`, `transform` and/or `inverse_transform` "
103-
attempt(finalize(message, verbosity); throw) do
104-
operations = String[]
105-
methods = MLJ.implemented_methods(fitted_machine.model)
106-
if :predict in methods
107-
predict(fitted_machine, first(data))
108-
push!(operations, "predict")
109-
end
110-
if :transform in methods
111-
W = transform(fitted_machine, first(data))
112-
push!(operations, "transform")
113-
if :inverse_transform in methods
114-
inverse_transform(fitted_machine, W)
115-
push!(operations, "inverse_transform")
116-
end
117-
end
118-
join(operations, ", ")
119-
end
120-
end
121-
1229
function threshold_prediction(model, data...; throw=false, verbosity=1)
12310
message = "[:threshold_predictor] Calling fit!/predict for threshold predictor "*
12411
"test) "
125-
attempt(finalize(message, verbosity); throw) do
12+
attempt(MTI.finalize(message, verbosity); throw) do
12613
tmodel = BinaryThresholdPredictor(model)
12714
mach = machine(tmodel, data...)
12815
fit!(mach, verbosity=0)
@@ -134,7 +21,7 @@ function evaluation(measure, model, resources, data...; throw=false, verbosity=1
13421
L = length(resources)
13522
message = L > 1 ? "[:accelerated_evaluation] " : "[evaluation] "
13623
message *= "Evaluating model performance using $L different resources. "
137-
attempt(finalize(message, verbosity); throw) do
24+
attempt(MTI.finalize(message, verbosity); throw) do
13825
es = map(resources) do resource
13926
evaluate(model, data...;
14027
measure=measure,
@@ -157,7 +44,7 @@ function tuned_pipe_evaluation(
15744
verbosity=1,
15845
)
15946
message = "[:tuned_pipe_evaluation] Evaluating perfomance in a tuned pipeline "
160-
attempt(finalize(message, verbosity); throw) do
47+
attempt(MTI.finalize(message, verbosity); throw) do
16148
pipe = identity |> model
16249
tuned_pipe = TunedModel(
16350
models=fill(pipe, 3),
@@ -172,7 +59,7 @@ function tuned_pipe_evaluation(
17259
end
17360

17461
function ensemble_prediction(model, data...; throw=false, verbosity=1)
175-
attempt(finalize("[:ensemble_prediction] Ensembling ", verbosity); throw) do
62+
attempt(MTI.finalize("[:ensemble_prediction] Ensembling ", verbosity); throw) do
17663
imodel = EnsembleModel(
17764
model=model,
17865
n=2,
@@ -186,7 +73,7 @@ end
18673
# the `model` must support iteration (`!isnothing(iteration_paramater(model))`)
18774
function iteration_prediction(measure, model, data...; throw=false, verbosity=1)
18875
message = "[:iteration_prediction] Iterating with controls "
189-
attempt(finalize(message, verbosity); throw) do
76+
attempt(MTI.finalize(message, verbosity); throw) do
19077
imodel = IteratedModel(model=model,
19178
measure=measure,
19279
controls=[Step(1),
@@ -240,7 +127,7 @@ function stack_evaluation(
240127
isregressor = AbstractVector{Continuous} <: target_scitype
241128
measure = isregressor ? LPLoss(2) : BrierScore()
242129

243-
attempt(finalize(message, verbosity); throw) do
130+
attempt(MTI.finalize(message, verbosity); throw) do
244131
es = map(resources) do resource
245132
stack = _stack(model, resource, isregressor)
246133
evaluate(

src/test.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -253,27 +253,27 @@ function test(model_proxies, data...; mod=Main, level=2, throw=false, verbosity=
253253
row = merge(row0, (; name, package_name))
254254

255255
# [model_type]:
256-
model_type, outcome = MLJTestIntegration.model_type(model_proxy, mod; throw, verbosity)
256+
model_type, outcome = MTI.model_type(model_proxy, mod; throw, verbosity)
257257
row = update(row, i, :model_type, model_type, outcome)
258258
outcome == "×" && continue
259259

260260
level > 1 || continue
261261

262262
# [model_instance]:
263263
model_instance, outcome =
264-
MLJTestIntegration.model_instance(model_type; throw, verbosity)
264+
MTI.model_instance(model_type; throw, verbosity)
265265
row = update(row, i, :model_instance, model_instance, outcome)
266266
outcome == "×" && continue
267267

268268
# [fitted_machine]:
269269
fitted_machine, outcome =
270-
MLJTestIntegration.fitted_machine(model_instance, data...; throw, verbosity)
270+
MTI.fitted_machine(model_instance, data...; throw, verbosity)
271271
row = update(row, i, :fitted_machine, fitted_machine, outcome)
272272
outcome == "×" && continue
273273

274274
# [operations]:
275275
operations, outcome =
276-
MLJTestIntegration.operations(fitted_machine, data...; throw, verbosity)
276+
MTI.operations(fitted_machine, data...; throw, verbosity)
277277
# special treatment to get list of operations in `summary`:
278278
if outcome == "×"
279279
row = update(row, i, :operations, operations, outcome)

test/attemptors.jl

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,12 @@
1-
@testset "attempt()" begin
2-
e = ArgumentError("elephant")
3-
bad() = throw(e)
4-
good() = 42
5-
6-
@test (@test_logs MLJTestIntegration.attempt(bad, "")) == (e, "×")
7-
@test(@test_logs(
8-
(:info, "look ×"),
9-
MLJTestIntegration.attempt(bad, "look "),
10-
) == (e, "×"))
11-
@test (@test_logs MLJTestIntegration.attempt(good, "")) == (42, "")
12-
@test (@test_logs(
13-
(:info, "look ✓"),
14-
MLJTestIntegration.attempt(good, "look "),
15-
) == (42, ""))
16-
@test_throws e MLJTestIntegration.attempt(bad, ""; throw=true)
17-
end
18-
19-
@testset "model_type" begin
1+
@testset "model_type with model proxy instead of type" begin
202

213
# test error thrown (not caught) if pkg missing from environment:
22-
@test_throws ArgumentError MLJTestIntegration.model_type(
4+
@test_throws ArgumentError MLJTestIntegration.MLJTestInterface.model_type(
235
(name="PCA", package_name="MultivariateStats"),
246
@__MODULE__
257
)
268

27-
M, outcome = MLJTestIntegration.model_type(
9+
M, outcome = MLJTestIntegration.MLJTestInterface.model_type(
2810
(name="DecisionTreeClassifier", package_name="DecisionTree"),
2911
@__MODULE__;
3012
verbosity=0

0 commit comments

Comments
 (0)