Skip to content

Commit a764053

Browse files
committed
allow tuples in resampling option & fix error message
1 parent d8a8946 commit a764053

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

src/resampling.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
# TYPE ALIASES
1+
# # LOCAL TYPE ALIASES
22

33
const AbstractRow = Union{AbstractVector{<:Integer}, Colon}
4-
const TrainTestPair = Tuple{AbstractRow,AbstractRow}
5-
const TrainTestPairs = AbstractVector{<:TrainTestPair}
4+
const TrainTestPair = Tuple{AbstractRow, AbstractRow}
5+
const TrainTestPairs = Union{
6+
NTuple{<:Any,TrainTestPair},
7+
AbstractVector{<:TrainTestPair},
8+
}
69

710

811
# # ERROR MESSAGES
@@ -93,6 +96,13 @@ const ERR_NEED_TARGET = ArgumentError(
9396
"""
9497
)
9598

99+
const ERR_BAD_RESAMPLING_OPTION = ArgumentError(
100+
"`resampling` must be an "*
101+
"`MLJ.ResamplingStrategy` or a vector (or tuple) of tuples "*
102+
"of the form `(train_rows, test_rows)`"
103+
)
104+
105+
96106
# ==================================================================
97107
## RESAMPLING STRATEGIES
98108

@@ -1402,10 +1412,6 @@ end
14021412
# ------------------------------------------------------------
14031413
# Core `evaluation` method, operating on train-test pairs
14041414

1405-
const AbstractRow = Union{AbstractVector{<:Integer}, Colon}
1406-
const TrainTestPair = Tuple{AbstractRow, AbstractRow}
1407-
const TrainTestPairs = AbstractVector{<:TrainTestPair}
1408-
14091415
_view(::Nothing, rows) = nothing
14101416
_view(weights, rows) = view(weights, rows)
14111417

@@ -1434,11 +1440,7 @@ function evaluate!(
14341440
# Note: `rows` and `repeats` are only passed to the final `PeformanceEvaluation`
14351441
# object to be returned and are not otherwise used here.
14361442

1437-
if !(resampling isa TrainTestPairs)
1438-
error("`resampling` must be an "*
1439-
"`MLJ.ResamplingStrategy` or tuple of rows "*
1440-
"of the form `(train_rows, test_rows)`")
1441-
end
1443+
resampling isa TrainTestPairs || throw(ERR_BAD_RESAMPLING_OPTION)
14421444

14431445
X = mach.args[1]()
14441446
y = mach.args[2]()

test/resampling.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,19 @@ end
276276
model = DeterministicConstantRegressor()
277277
mach = machine(model, X, y, cache=cache)
278278

279+
# check catch for bad `resampling` option:
280+
@test_throws(
281+
MLJBase.ERR_BAD_RESAMPLING_OPTION,
282+
evaluate(model, X, y; resampling="junk", verbosity=0, acceleration=accel),
283+
)
284+
285+
# check we can provide tuples of pairs, instead of vectors of pairs, in
286+
# `resampling` option:
287+
e1 = evaluate(model, X, y; resampling, verbosity=0, acceleration=accel)
288+
e2 = evaluate(model, X, y;
289+
resampling=Tuple(resampling), verbosity=0, acceleration=accel)
290+
@test e1.measurement[1] e2.measurement[1]
291+
279292
# check detection of incompatible measure (cross_entropy):
280293
@test_throws(
281294
MLJBase.err_incompatible_prediction_types(model, cross_entropy),

0 commit comments

Comments
 (0)