Skip to content

Commit 13dc516

Browse files
authored
Merge pull request #1010 from JuliaAI/resampling-option
Allow tuples in the `resampling` option for `evaluate`
2 parents 88f2243 + f5fcca0 commit 13dc516

File tree

3 files changed

+39
-13
lines changed

3 files changed

+39
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJBase"
22
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "1.8.2"
4+
version = "1.9.0"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/resampling.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
# TYPE ALIASES
1+
# # LOCAL TYPE ALIASES
22

33
const AbstractRow = Union{AbstractVector{<:Integer}, Colon}
4-
const TrainTestPair = Tuple{AbstractRow,AbstractRow}
5-
const TrainTestPairs = AbstractVector{<:TrainTestPair}
4+
const TrainTestPair = Tuple{AbstractRow, AbstractRow}
5+
const TrainTestPairs = Union{
6+
Tuple{Vararg{TrainTestPair}},
7+
AbstractVector{<:TrainTestPair},
8+
}
69

710

811
# # ERROR MESSAGES
@@ -93,6 +96,19 @@ const ERR_NEED_TARGET = ArgumentError(
9396
"""
9497
)
9598

99+
const ERR_BAD_RESAMPLING_OPTION = ArgumentError(
100+
"`resampling` must be an "*
101+
"`MLJ.ResamplingStrategy` or a vector (or tuple) of tuples "*
102+
"of the form `(train_rows, test_rows)`"
103+
)
104+
105+
const ERR_EMPTY_RESAMPLING_OPTION = ArgumentError(
106+
"`resampling` cannot be emtpy. It must be an "*
107+
"`MLJ.ResamplingStrategy` or a vector (or tuple) of tuples "*
108+
"of the form `(train_rows, test_rows)`"
109+
)
110+
111+
96112
# ==================================================================
97113
## RESAMPLING STRATEGIES
98114

@@ -1402,10 +1418,6 @@ end
14021418
# ------------------------------------------------------------
14031419
# Core `evaluation` method, operating on train-test pairs
14041420

1405-
const AbstractRow = Union{AbstractVector{<:Integer}, Colon}
1406-
const TrainTestPair = Tuple{AbstractRow, AbstractRow}
1407-
const TrainTestPairs = AbstractVector{<:TrainTestPair}
1408-
14091421
_view(::Nothing, rows) = nothing
14101422
_view(weights, rows) = view(weights, rows)
14111423

@@ -1434,11 +1446,8 @@ function evaluate!(
14341446
# Note: `rows` and `repeats` are only passed to the final `PeformanceEvaluation`
14351447
# object to be returned and are not otherwise used here.
14361448

1437-
if !(resampling isa TrainTestPairs)
1438-
error("`resampling` must be an "*
1439-
"`MLJ.ResamplingStrategy` or an vector tuples "*
1440-
"of the form `(train_rows, test_rows)`")
1441-
end
1449+
isempty(resampling) && throw(ERR_EMPTY_RESAMPLING_OPTION)
1450+
resampling isa TrainTestPairs || throw(ERR_BAD_RESAMPLING_OPTION)
14421451

14431452
X = mach.args[1]()
14441453
y = mach.args[2]()

test/resampling.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,23 @@ end
276276
model = DeterministicConstantRegressor()
277277
mach = machine(model, X, y, cache=cache)
278278

279+
# check catch for bad `resampling` options:
280+
@test_throws(
281+
MLJBase.ERR_BAD_RESAMPLING_OPTION,
282+
evaluate(model, X, y; resampling="junk", verbosity=0, acceleration=accel),
283+
)
284+
@test_throws(
285+
MLJBase.ERR_EMPTY_RESAMPLING_OPTION,
286+
evaluate(model, X, y; resampling=[], verbosity=0, acceleration=accel),
287+
)
288+
289+
# check we can provide tuples of pairs, instead of vectors of pairs, in
290+
# `resampling` option:
291+
e1 = evaluate(model, X, y; resampling, verbosity=0, acceleration=accel)
292+
e2 = evaluate(model, X, y;
293+
resampling=Tuple(resampling), verbosity=0, acceleration=accel)
294+
@test e1.measurement[1] e2.measurement[1]
295+
279296
# check detection of incompatible measure (cross_entropy):
280297
@test_throws(
281298
MLJBase.err_incompatible_prediction_types(model, cross_entropy),

0 commit comments

Comments
 (0)