Commit 150f107

Means, entropies
1 parent b5adba8 commit 150f107

File tree

1 file changed: +105 -0 lines changed

src/vstats.jl

Lines changed: 105 additions & 0 deletions
@@ -38,3 +38,108 @@ vtmean(A; dims=:) = vtmean(identity, A, dims=dims)
# Naturally, faster than the overflow/underflow-safe logsumexp, but if one can tolerate it...
vlse(A; dims=:) = vmapreducethen(exp, +, log, A, dims=dims)
vtlse(A; dims=:) = vtmapreducethen(exp, +, log, A, dims=dims)

# Assorted functions from StatsBase
function vgeomean(A; dims=:)
    c = 1 / _denom(A, dims)
    vmapreducethen(log, +, x -> exp(c * x), A, dims=dims)
end
function vgeomean(f::F, A; dims=:) where {F}
    c = 1 / _denom(A, dims)
    vmapreducethen(x -> log(f(x)), +, x -> exp(c * x), A, dims=dims)
end

function vharmmean(A; dims=:)
    c = 1 / _denom(A, dims)
    vmapreducethen(inv, +, x -> inv(c * x), A, dims=dims)
end
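
These follow the usual log-domain formulations: the geometric mean is exp of the mean log, and the harmonic mean is the reciprocal of the mean reciprocal. A minimal sanity sketch, assuming the definitions above are loaded, with StatsBase used only for comparison:

using StatsBase: geomean, harmmean
A = rand(100) .+ 0.5           # positive, so log and inv are well-behaved
vgeomean(A) ≈ geomean(A)       # exp(mean(log.(A)))
vharmmean(A) ≈ harmmean(A)     # inv(mean(inv.(A)))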

_xlogx(x::T) where {T} = ifelse(iszero(x), zero(T), x * log(x))
_xlogy(x::T, y::T) where {T} = ifelse(iszero(x) & !isnan(y), zero(T), x * log(y))

ventropy(A; dims=:) = vmapreducethen(_xlogx, +, -, A, dims=dims)
ventropy(A, b::Real; dims=:) = (c = -1 / log(b); vmapreducethen(_xlogx, +, x -> x * c, A, dims=dims))

vcrossentropy(p, q; dims=:) = vmapreducethen(_xlogy, +, -, p, q, dims=dims)
vcrossentropy(p, q, b::Real; dims=:) = (c = -1 / log(b); vmapreducethen(_xlogy, +, x -> x * c, p, q, dims=dims))
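
Both follow the StatsBase conventions: H(p) = -∑pᵢ log(pᵢ) and H(p, q) = -∑pᵢ log(qᵢ), with an optional base b applied as a final 1/log(b) rescaling. A hedged check against StatsBase:

using StatsBase: entropy, crossentropy
p = rand(50); p ./= sum(p)
q = rand(50); q ./= sum(q)
ventropy(p) ≈ entropy(p)
ventropy(p, 2) ≈ entropy(p, 2)            # bits rather than nats
vcrossentropy(p, q) ≈ crossentropy(p, q)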

# max-, Shannon, collision and min- entropy assume that p ∈ ℝⁿ, pᵢ ≥ 0, ∑pᵢ = 1
_vmaxentropy(p, dims::NTuple{M, Int}) where {M} =
    fill!(similar(p, ntuple(d -> d ∈ dims ? 1 : size(p, d), Val(M))), log(_denom(p, dims)))
_vmaxentropy(p, ::Colon) = log(length(p))
vmaxentropy(p; dims=:) = _vmaxentropy(p, dims)
vshannonentropy(p; dims=:) = vmapreducethen(_xlogx, +, -, p, dims=dims)
vcollisionentropy(p; dims=:) = vmapreducethen(abs2, +, x -> -log(x), p, dims=dims)
vminentropy(p; dims=:) = vmapreducethen(identity, max, x -> -log(x), p, dims=dims)
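
On the uniform distribution all four coincide at log(n), which makes a convenient illustrative check (again assuming the definitions above are loaded):

n = 16
p = fill(1/n, n)
vmaxentropy(p) ≈ vshannonentropy(p) ≈ vcollisionentropy(p) ≈ vminentropy(p) ≈ log(n)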

_vrenyientropy(p, α::T, dims) where {T<:Integer} =
    (c = one(T) / (one(T) - α); vmapreducethen(x -> x^α, +, x -> c * log(x), p, dims=dims))
_vrenyientropy(p, α::T, dims) where {T<:AbstractFloat} =
    (c = one(T) / (one(T) - α); vmapreducethen(x -> exp(α * log(x)), +, x -> c * log(x), p, dims=dims))
_vrenyientropy(p, α::Rational{T}, dims) where {T} = _vrenyientropy(p, float(α), dims)
function vrenyientropy(p, α::Real; dims=:)
    if α == 0
        vmaxentropy(p, dims=dims)
    elseif α == 1
        vshannonentropy(p, dims=dims)
    elseif α == 2
        vcollisionentropy(p, dims=dims)
    elseif isinf(α)
        vminentropy(p, dims=dims)
    else
        _vrenyientropy(p, α, dims)
    end
end
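
The branches dispatch the closed-form special cases (α = 0, 1, 2, ∞) to the dedicated kernels and everything else to _vrenyientropy. A hedged consistency sketch; the near-2 evaluation exercises the generic floating-point path:

p = rand(32); p ./= sum(p)
vrenyientropy(p, 2) ≈ vcollisionentropy(p)           # special-cased branch
vrenyientropy(p, 2 + 1e-12) ≈ vcollisionentropy(p)   # generic path, approximately
vrenyientropy(p, 3//2) ≈ vrenyientropy(p, 1.5)       # Rational promotes to float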

# Loosened restrictions: p ∈ ℝⁿ, pᵢ ≥ 0, ∑pᵢ > 0; that is, if one normalized p, a valid
# probability vector would be produced. Thus, H(x, α) = (α/(1-α)) * ((1/α) * log∑xᵢ^α - log∑xᵢ)
# H(x, α) = (α / (1 - α)) * ((1/α) * log(sum(z -> z^α, x)) - log(sum(x)))
vrenyientropynorm(p, α::Real; dims=:) =
    vrenyientropy(p, α, dims=dims) .- (α/(1-α)) .* log.(vnorm(p, 1, dims=dims))
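
Equivalently, vrenyientropynorm on an unnormalized nonnegative vector should agree with vrenyientropy on its 1-normalization; a hedged check:

x = rand(20)                                          # nonnegative, unnormalized
vrenyientropynorm(x, 1.5) ≈ vrenyientropy(x ./ sum(x), 1.5)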

# Scratch comparisons left from development (x2 an arbitrary test vector, x2n presumably
# x2 normalized to unit 1-norm); commented out so the file remains loadable:
# vrenyientropy(x2n, 1.5)
# renyientropy(x2n, 1.5)
#
# den = sum(abs2, x2)
# sum(abs2.(x2) ./ den)
# sum(abs2, x2 ./ den)
#
# (abs2(1 / sum(abs, x2)) * sum(abs2, x2))
# (1 / sum(abs, x2)) * norm(x2)
# norm(x2n)
# norm(x2) / norm(x2, 1)
# log(norm(x2))
# log(norm(x2)) - log(norm(x2, 1))

# # StatsBase handling of pᵢ = qᵢ = 0
# _xlogxdy(x::T, y::T) where {T} = _xlogy(x, ifelse(iszero(x) & iszero(y), zero(T), x / y))
# vkldivergence(p, q; dims=:) = vvmapreduce(_xlogxdy, +, p, q, dims=dims)
# Slightly more efficient (and likely more stable)
_klterm(x::T, y::T) where {T} = _xlogy(x, x) - _xlogy(x, y)
vkldivergence(p, q; dims=:) = vvmapreduce(_klterm, +, p, q, dims=dims)
vkldivergence(p, q, b::Real; dims=:) = (c = 1 / log(b); vmapreducethen(_klterm, +, x -> x * c, p, q, dims=dims))
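
The _klterm formulation expands pᵢ log(pᵢ/qᵢ) into _xlogy(pᵢ, pᵢ) - _xlogy(pᵢ, qᵢ), avoiding the division while preserving the 0·log(0) = 0 convention. A hedged comparison against StatsBase:

using StatsBase: kldivergence
p = rand(50); p ./= sum(p)
q = rand(50); q ./= sum(q)
vkldivergence(p, q) ≈ kldivergence(p, q)
vkldivergence(p, q, 2) ≈ kldivergence(p, q, 2)   # base-2, i.e. bits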

vcounteq(x, y; dims=:) = vvmapreduce(==, +, x, y, dims=dims)
vtcounteq(x, y; dims=:) = vtmapreduce(==, +, x, y, dims=dims)
vcountne(x, y; dims=:) = vvmapreduce(!=, +, x, y, dims=dims)
vtcountne(x, y; dims=:) = vtmapreduce(!=, +, x, y, dims=dims)
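
These count elementwise (dis)agreements by summing the Bool comparisons, e.g.:

vcounteq([1, 2, 3, 4], [1, 0, 3, 0]) == 2
vcountne([1, 2, 3, 4], [1, 0, 3, 0]) == 2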

function vmeanad(x, y; dims=:)
    c = 1 / _denom(x, dims)
    vmapreducethen((xᵢ, yᵢ) -> abs(xᵢ - yᵢ), +, z -> c * z, x, y, dims=dims)
end
function vtmeanad(x, y; dims=:)
    c = 1 / _denom(x, dims)
    vtmapreducethen((xᵢ, yᵢ) -> abs(xᵢ - yᵢ), +, z -> c * z, x, y, dims=dims)
end

vmaxad(x, y; dims=:) = vvmapreduce((xᵢ, yᵢ) -> abs(xᵢ - yᵢ), max, x, y, dims=dims)
vtmaxad(x, y; dims=:) = vtmapreduce((xᵢ, yᵢ) -> abs(xᵢ - yᵢ), max, x, y, dims=dims)
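
Mean and maximum absolute deviation between paired arrays; a hedged check against their StatsBase counterparts:

using StatsBase: meanad, maxad
x, y = rand(100), rand(100)
vmeanad(x, y) ≈ meanad(x, y)   # mean(abs.(x .- y))
vmaxad(x, y) ≈ maxad(x, y)     # maximum(abs.(x .- y))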

# generalized KL divergence: sum(a*log(a/b) - a + b)
vgkldiv(x, y; dims=:) = vvmapreduce((xᵢ, yᵢ) -> xᵢ * (log(xᵢ) - log(yᵢ)) - xᵢ + yᵢ, +, x, y, dims=dims)
vtgkldiv(x, y; dims=:) = vtmapreduce((xᵢ, yᵢ) -> xᵢ * (log(xᵢ) - log(yᵢ)) - xᵢ + yᵢ, +, x, y, dims=dims)
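
For strictly positive inputs this should agree with StatsBase.gkldivergence; a hedged check:

using StatsBase: gkldivergence
x, y = rand(50) .+ 0.1, rand(50) .+ 0.1   # bounded away from zero
vgkldiv(x, y) ≈ gkldivergence(x, y)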
