Skip to content

Commit 2209563

Browse files
committed
relax restrictions on model type in resampling
1 parent 370b3da commit 2209563

File tree

2 files changed

+46
-12
lines changed

2 files changed

+46
-12
lines changed

src/resampling.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ const ERR_INVALID_OPERATION = ArgumentError(
3131
_ambiguous_operation(model, measure) =
3232
"`$measure` does not support a `model` with "*
3333
"`prediction_type(model) == :$(prediction_type(model))`. "
34-
err_ambiguous_operation(model, measure) = ArgumentError(
35-
_ambiguous_operation(model, measure)*
36-
"\nUnable to infer an appropriate operation for `$measure`. "*
37-
"Explicitly specify `operation=...` or `operations=...`. ")
3834
err_incompatible_prediction_types(model, measure) = ArgumentError(
3935
_ambiguous_operation(model, measure)*
4036
"If your model is truly making probabilistic predictions, try explicitly "*
@@ -65,11 +61,25 @@ ERR_MEASURES_DETERMINISTIC(measure) = ArgumentError(
6561
"and so is not supported by `$measure`. "*LOG_AVOID
6662
)
6763

68-
# ==================================================================
69-
## MODEL TYPES THAT CAN BE EVALUATED
64+
err_ambiguous_operation(model, measure) = ArgumentError(
65+
_ambiguous_operation(model, measure)*
66+
"\nUnable to infer an appropriate operation for `$measure`. "*
67+
"Explicitly specify `operation=...` or `operations=...`. "*
68+
"Possible value(s) are: $PREDICT_OPERATIONS_STRING. "
69+
)
70+
71+
ERR_UNSUPPORTED_PREDICTION_TYPE = ArgumentError(
72+
"""
73+
74+
The `prediction_type` of your model needs to be one of: `:deterministic`,
75+
`:probabilistic`, or `:interval`. Does your model implement one of these operations:
76+
$PREDICT_OPERATIONS_STRING? If so, you can try explicitly specifying `operation=...`
77+
or `operations=...` (and consider posting an issue to have the model review it's
78+
definition of `MLJModelInterface.prediction_type`). Otherwise, performance
79+
evaluation is not supported.
7080
71-
# not exported:
72-
const Measurable = Union{Supervised, Annotator}
81+
"""
82+
)
7383

7484
# ==================================================================
7585
## RESAMPLING STRATEGIES
@@ -987,7 +997,7 @@ function _actual_operations(operation::Nothing,
987997
throw(err_ambiguous_operation(model, m))
988998
end
989999
else
990-
throw(err_ambiguous_operation(model, m))
1000+
throw(ERR_UNSUPPORTED_PREDICTION_TYPE)
9911001
end
9921002
end
9931003
end
@@ -1137,7 +1147,7 @@ See also [`evaluate`](@ref), [`PerformanceEvaluation`](@ref),
11371147
11381148
"""
11391149
function evaluate!(
1140-
mach::Machine{<:Measurable};
1150+
mach::Machine;
11411151
resampling=CV(),
11421152
measures=nothing,
11431153
measure=measures,
@@ -1235,7 +1245,7 @@ Returns a [`PerformanceEvaluation`](@ref) object.
12351245
See also [`evaluate!`](@ref).
12361246
12371247
"""
1238-
evaluate(model::Measurable, args...; cache=true, kwargs...) =
1248+
evaluate(model::Model, args...; cache=true, kwargs...) =
12391249
evaluate!(machine(model, args...; cache=cache); kwargs...)
12401250

12411251
# -------------------------------------------------------------------

test/resampling.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ end
2525
struct DummyInterval <: Interval end
2626
dummy_interval=DummyInterval()
2727

28+
struct GoofyTransformer <: Unsupervised end
29+
2830
dummy_measure_det(yhat, y) = 42
2931
API.@trait(
3032
typeof(dummy_measure_det),
@@ -115,6 +117,12 @@ API.@trait(
115117
MLJBase.err_ambiguous_operation(dummy_interval, LogLoss()),
116118
MLJBase._actual_operations(nothing,
117119
[LogLoss(), ], dummy_interval, 1))
120+
121+
# model not have a valid `prediction_type`:
122+
@test_throws(
123+
MLJBase.ERR_UNSUPPORTED_PREDICTION_TYPE,
124+
MLJBase._actual_operations(nothing, [LogLoss(),], GoofyTransformer(), 0),
125+
)
118126
end
119127

120128
@everywhere begin
@@ -935,7 +943,23 @@ end
935943
end
936944
end
937945

938-
# DUMMY LOGGER
946+
947+
# # TRANSFORMER WITH PREDICT
948+
949+
struct PredictingTransformer <:Unsupervised end
950+
MLJBase.fit(::PredictingTransformer, verbosity, X, y) = (mean(y), nothing, nothing)
951+
MLJBase.predict(::PredictingTransformer, fitresult, X) = fill(fitresult, nrows(X))
952+
MLJBase.prediction_type(::Type{<:PredictingTransformer}) = :deterministic
953+
954+
@testset "`Unsupervised` model with a predict" begin
955+
X = rand(10)
956+
y = fill(42.0, 10)
957+
e = evaluate(PredictingTransformer(), X, y, resampling=Holdout(), measure=l2)
958+
@test e.measurement[1] 0
959+
end
960+
961+
962+
# # DUMMY LOGGER
939963

940964
struct DummyLogger end
941965

0 commit comments

Comments
 (0)