Skip to content

Commit c5488c5

Browse files
authored
Merge pull request #9 from JuliaAI/measurements-tweak
Fix measurements method to always return `n` objects, where `n = numobs(y)`
2 parents e5488d2 + 09bd5c5 commit c5488c5

File tree

5 files changed

+14
-7
lines changed

5 files changed

+14
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StatisticalMeasuresBase"
22
uuid = "c062fc1d-0d66-479b-b6ac-8b44719de4cc"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/api.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ on data.
1111
1212
# New implementations
1313
14-
Overloading this function is optional. A fallback returns the aggregated measure, repeated
15-
`n` times, where `n = MLUtils.numobs(y)`. It is not typically necessary to overload
14+
Overloading this function for new measure types is optional. A fallback returns the
15+
aggregated measure, repeated `n` times, where `n = MLUtils.numobs(y)` (which falls back to
16+
`length(y)` if `numobs` is not implemented). It is not typically necessary to overload
1617
`measurements` for wrapped measures. All [`multimeasure`](@ref)s provide the obvious
1718
fallback and other wrappers simply forward the `measurements` method of the atomic
1819
measure. If overloading, use the following signatures:
@@ -24,7 +25,6 @@ measure. If overloading, use the following signatures:
2425
2526
"""
2627
function measurements(measure, yhat, y, args...)
27-
consumes_multiple_observations(measure) || return measure(yhat, y, args...)
2828
m = measure(yhat, y, args...)
2929
fill(m, MLUtils.numobs(y))
3030
end

src/traits.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ with the same number of observations as `y`.
9191
9292
# New implementations
9393
94-
Overloading the trait is optional and it is typically not overloaded. The general fallback
95-
returns `false` but it is `true` for any [`multimeasure`](@ref), and the value is
94+
Overload this trait for a new measure type that consumes multiple observations, unless it
95+
has been constructed using `multimeaure` or is an $API.jl wrap thereof. The general
96+
fallback returns `false` but it is `true` for any [`multimeasure`](@ref), and the value is
9697
propagated by other wrappers.
9798
9899
"""

test/api.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ w = 10:10:10N
99
measurements(measure, 2.3, 3.4) == LPLossOnScalars()(2.3, 3.4)
1010
measure = MeanAbsoluteError()
1111
measurements(measure, yhat, y) == fill(measure(yhat, y), N)
12+
13+
# if a measure does not overload `consumes_multiple_observations` but really does, then
14+
# `meausurements` should nevertheless have the expected behavior:
15+
badly_implemented_rms(yhat, y) = (yhat - y).^2 |> mean |> sqrt
16+
μ = badly_implemented_rms([4, 5], [1, 1])
17+
@test measurements(badly_implemented_rms, [4, 5], [1, 1]) [μ, μ]
1218
end
1319

1420
true

test/measure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ API.is_measure(::Measure{typeof(min)}) = true
99
@testset "calling" begin
1010
measure = Measure(min)
1111
@test measure(2, -3) == 3
12-
@test measurements(measure, 2, -3) == 3
12+
@test measurements(measure, 2, -3) == [3,]
1313
end
1414

1515
true

0 commit comments

Comments
 (0)