diff --git a/Project.toml b/Project.toml index 9ac1de4..e7c992c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "StatisticalMeasures" uuid = "a19d573c-0a75-4610-95b3-7071388c7541" authors = ["Anthony D. Blaom "] -version = "0.3.0" +version = "0.3.1" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/continuous.jl b/src/continuous.jl index 6be51ff..4286dd2 100644 --- a/src/continuous.jl +++ b/src/continuous.jl @@ -545,3 +545,55 @@ RSquared const rsq = RSquared() "$RSquaredDoc" const rsquared = rsq + +# ------------------------------------------------------------------------- +# Willmott index of agreement (d) + +# type for measure without argument checks: +struct _WillmottD end + +function (::_WillmottD)(yhat, y) + μ = aggregate(y) # mean + # numerator: Σ_i (ŷ_i - y_i)^2 + num = LPSumLoss(p=2)(yhat, y) + # denominator: Σ_i (|ŷ_i - μ| + |y_i - μ|)^2 + den = multimeasure((yhat, y) -> (abs(yhat - μ) + abs(y - μ))^2; mode=Sum())(yhat, y) + return den == 0 ? (num == 0 ? 1.0 : 0.0) : 1 - num/den +end + +WillmottD() = _WillmottD() |> API.robust_measure |> API.fussy_measure +const WillmottDType = API.FussyMeasure{<:API.RobustMeasure{<:_WillmottD}} + +@trait( + _WillmottD, + consumes_multiple_observations = true, + kind_of_proxy = LearnAPI.Point(), + observation_scitype = Union{Missing,Infinite}, + orientation = Score(), + human_name = "Willmott index of agreement (d)", +) + +@fix_show WillmottD::WillmottDType + +register(WillmottD, "willmott_d") + +const WillmottDDoc = docstring( + "WillmottD()", + scitype=DOC_INFINITE, + body= +""" +Returns Willmott index of agreement (d) + +``d = 1 - \\dfrac{\\sum (ŷ_i - y_i)^2}{\\sum (|ŷ_i - \\bar y| + |y_i - \\bar y|)^2}``, + +where ``\\bar y`` is the mean of the targets. The value lies in ``[0,1]`` with higher +being better. + +References: Willmott [(1981)](https://doi.org/10.1080/02723646.1981.10642213) +""", +) + +"$WillmottDDoc" +WillmottD +"$WillmottDDoc" +const willmott_d = WillmottD() diff --git a/test/continuous.jl b/test/continuous.jl index 226bd87..dbbc3e8 100644 --- a/test/continuous.jl +++ b/test/continuous.jl @@ -31,6 +31,16 @@ rng = srng(666899) yhat = rand(rng, 4) @test isapprox(log_cosh(yhat, y), mean(log.(cosh.(yhat - y)))) @test rsq(yhat, y) == 1 - sum((yhat - y).^2)/sum((y .- mean(y)).^2) + let + num = sum((yhat - y).^2) + den = sum((abs.(yhat .- mean(y)) .+ abs.(y .- mean(y))).^2) + @test isapprox(willmott_d(yhat, y), den == 0 ? (num == 0 ? 1.0 : 0.0) : 1 - num/den) + # additional tests for willmott_d + @test willmott_d(yhat, yhat) == 1 + @test willmott_d(y, y) == 1 + @test willmott_d(yhat .+ 1, zeros(length(yhat))) == 0 # yhat .+ 1 ensures it's not all zeros + @test willmott_d(zeros(4), zeros(4)) == 1 + end # a multi-target test where there is a parameter: y = rand(rng, 2, 10)