Skip to content

Commit 07f7ca6

Browse files
authored
Merge pull request #985 from JuliaAI/relax-resampling-type
Relax restrictions on model type in resampling (`evaluate!`)
2 parents 9d78c6c + 64b9481 commit 07f7ca6

File tree

2 files changed

+66
-12
lines changed

2 files changed

+66
-12
lines changed

src/resampling.jl

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ const ERR_INVALID_OPERATION = ArgumentError(
3131
_ambiguous_operation(model, measure) =
3232
"`$measure` does not support a `model` with "*
3333
"`prediction_type(model) == :$(prediction_type(model))`. "
34-
err_ambiguous_operation(model, measure) = ArgumentError(
35-
_ambiguous_operation(model, measure)*
36-
"\nUnable to infer an appropriate operation for `$measure`. "*
37-
"Explicitly specify `operation=...` or `operations=...`. ")
3834
err_incompatible_prediction_types(model, measure) = ArgumentError(
3935
_ambiguous_operation(model, measure)*
4036
"If your model is truly making probabilistic predictions, try explicitly "*
@@ -65,11 +61,37 @@ ERR_MEASURES_DETERMINISTIC(measure) = ArgumentError(
6561
"and so is not supported by `$measure`. "*LOG_AVOID
6662
)
6763

68-
# ==================================================================
69-
## MODEL TYPES THAT CAN BE EVALUATED
64+
err_ambiguous_operation(model, measure) = ArgumentError(
65+
_ambiguous_operation(model, measure)*
66+
"\nUnable to infer an appropriate operation for `$measure`. "*
67+
"Explicitly specify `operation=...` or `operations=...`. "*
68+
"Possible value(s) are: $PREDICT_OPERATIONS_STRING. "
69+
)
70+
71+
const ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
72+
"""
7073
71-
# not exported:
72-
const Measurable = Union{Supervised, Annotator}
74+
The `prediction_type` of your model needs to be one of: `:deterministic`,
75+
`:probabilistic`, or `:interval`. Does your model implement one of these operations:
76+
$PREDICT_OPERATIONS_STRING? If so, you can try explicitly specifying `operation=...`
77+
or `operations=...` (and consider posting an issue to have the model review it's
78+
definition of `MLJModelInterface.prediction_type`). Otherwise, performance
79+
evaluation is not supported.
80+
81+
"""
82+
)
83+
84+
const ERR_NEED_TARGET = ArgumentError(
85+
"""
86+
87+
To evaluate a model's performance you must provide a target variable `y`, as in
88+
`evaluate(model, X, y; options...)` or
89+
90+
mach = machine(model, X, y)
91+
evaluate!(mach; options...)
92+
93+
"""
94+
)
7395

7496
# ==================================================================
7597
## RESAMPLING STRATEGIES
@@ -987,7 +1009,7 @@ function _actual_operations(operation::Nothing,
9871009
throw(err_ambiguous_operation(model, m))
9881010
end
9891011
else
990-
throw(err_ambiguous_operation(model, m))
1012+
throw(ERR_UNSUPPORTED_PREDICTION_TYPE)
9911013
end
9921014
end
9931015
end
@@ -1137,7 +1159,7 @@ See also [`evaluate`](@ref), [`PerformanceEvaluation`](@ref),
11371159
11381160
"""
11391161
function evaluate!(
1140-
mach::Machine{<:Measurable};
1162+
mach::Machine;
11411163
resampling=CV(),
11421164
measures=nothing,
11431165
measure=measures,
@@ -1160,6 +1182,8 @@ function evaluate!(
11601182
# weights, measures, operations, and dispatches a
11611183
# strategy-specific `evaluate!`
11621184

1185+
length(mach.args) > 1 || throw(ERR_NEED_TARGET)
1186+
11631187
repeats > 0 || error("Need `repeats > 0`. ")
11641188

11651189
if resampling isa TrainTestPairs
@@ -1235,7 +1259,7 @@ Returns a [`PerformanceEvaluation`](@ref) object.
12351259
See also [`evaluate!`](@ref).
12361260
12371261
"""
1238-
evaluate(model::Measurable, args...; cache=true, kwargs...) =
1262+
evaluate(model::Model, args...; cache=true, kwargs...) =
12391263
evaluate!(machine(model, args...; cache=cache); kwargs...)
12401264

12411265
# -------------------------------------------------------------------

test/resampling.jl

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ end
2525
struct DummyInterval <: Interval end
2626
dummy_interval=DummyInterval()
2727

28+
struct GoofyTransformer <: Unsupervised end
29+
2830
dummy_measure_det(yhat, y) = 42
2931
API.@trait(
3032
typeof(dummy_measure_det),
@@ -115,6 +117,12 @@ API.@trait(
115117
MLJBase.err_ambiguous_operation(dummy_interval, LogLoss()),
116118
MLJBase._actual_operations(nothing,
117119
[LogLoss(), ], dummy_interval, 1))
120+
121+
# model does not have a valid `prediction_type`:
122+
@test_throws(
123+
MLJBase.ERR_UNSUPPORTED_PREDICTION_TYPE,
124+
MLJBase._actual_operations(nothing, [LogLoss(),], GoofyTransformer(), 0),
125+
)
118126
end
119127

120128
@everywhere begin
@@ -935,7 +943,29 @@ end
935943
end
936944
end
937945

938-
# DUMMY LOGGER
946+
947+
# # TRANSFORMER WITH PREDICT
948+
949+
struct PredictingTransformer <:Unsupervised end
950+
MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing)
951+
MLJBase.fit(::PredictingTransformer, verbosity, X) = (nothing, nothing, nothing)
952+
MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X))
953+
MLJBase.predict(::PredictingTransformer, ::Nothing, X) = nothing
954+
MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic
955+
956+
@testset "`Unsupervised` model with a predict" begin
957+
X = rand(10)
958+
y = fill(42.0, 10)
959+
e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2)
960+
@test e.measurement[1] 0
961+
@test_throws(
962+
MLJBase.ERR_NEED_TARGET,
963+
evaluate(PredictingTransformer(), X, measure=l2),
964+
)
965+
end
966+
967+
968+
# # DUMMY LOGGER
939969

940970
struct DummyLogger end
941971

0 commit comments

Comments
 (0)