Skip to content

Commit db50c6d

Browse files
committed
force measurements to always a vector
1 parent e5488d2 commit db50c6d

File tree

4 files changed

+13
-6
lines changed

4 files changed

+13
-6
lines changed

src/api.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ on data.
1111
1212
# New implementations
1313
14-
Overloading this function is optional. A fallback returns the aggregated measure, repeated
15-
`n` times, where `n = MLUtils.numobs(y)`. It is not typically necessary to overload
14+
Overloading this function for new measure types is optional. A fallback returns the
15+
aggregated measure, repeated `n` times, where `n = MLUtils.numobs(y)` (which falls back to
16+
`length(y)` if `numobs` is not implemented). It is not typically necessary to overload
1617
`measurements` for wrapped measures. All [`multimeasure`](@ref)s provide the obvious
1718
fallback and other wrappers simply forward the `measurements` method of the atomic
1819
measure. If overloading, use the following signatures:
@@ -24,7 +25,6 @@ measure. If overloading, use the following signatures:
2425
2526
"""
2627
function measurements(measure, yhat, y, args...)
27-
consumes_multiple_observations(measure) || return measure(yhat, y, args...)
2828
m = measure(yhat, y, args...)
2929
fill(m, MLUtils.numobs(y))
3030
end

src/traits.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ with the same number of observations as `y`.
9191
9292
# New implementations
9393
94-
Overloading the trait is optional and it is typically not overloaded. The general fallback
95-
returns `false` but it is `true` for any [`multimeasure`](@ref), and the value is
94+
Overload this trait for a new measure type that consumes multiple observations, unless it
95+
has been constructed using `multimeaure` or is an $API.jl wrap thereof. The general
96+
fallback returns `false` but it is `true` for any [`multimeasure`](@ref), and the value is
9697
propagated by other wrappers.
9798
9899
"""

test/api.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ w = 10:10:10N
99
measurements(measure, 2.3, 3.4) == LPLossOnScalars()(2.3, 3.4)
1010
measure = MeanAbsoluteError()
1111
measurements(measure, yhat, y) == fill(measure(yhat, y), N)
12+
13+
# if a measure does not overload `consumes_multiple_observations` but really does, then
14+
# `meausurements` should nevertheless have the expected behavior:
15+
badly_implemented_rms(yhat, y) = (yhat - y).^2 |> mean |> sqrt
16+
μ = badly_implemented_rms([4, 5], [1, 1])
17+
@test measurements(badly_implemented_rms, [4, 5], [1, 1]) [μ, μ]
1218
end
1319

1420
true

test/measure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ API.is_measure(::Measure{typeof(min)}) = true
99
@testset "calling" begin
1010
measure = Measure(min)
1111
@test measure(2, -3) == 3
12-
@test measurements(measure, 2, -3) == 3
12+
@test measurements(measure, 2, -3) == [3,]
1313
end
1414

1515
true

0 commit comments

Comments
 (0)