Skip to content

Commit 7e5b111

Browse files
Rénya divergence
1 parent 34176af commit 7e5b111

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

src/vstats.jl

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,10 @@ end
118118
################
119119
# Divergences
120120

121+
# Kullback-Leibler
121122
# # StatsBase handling of pᵢ = qᵢ = 0
122-
# _xlogxdy(x::T, y::T) where {T} = _xlogy(x, ifelse(iszero(x) & iszero(y), zero(T), x / y))
123123
# vkldivergence(p, q; dims=:) = vvmapreduce(_xlogxdy, +, p, q, dims=dims)
124124
# Slightly more efficient (and likely more stable)
125-
_klterm(x::T, y::T) where {T} = _xlogy(x, x) - _xlogy(x, y)
126125
vkldivergence(p, q; dims=:) = vvmapreduce(_klterm, +, p, q, dims=dims)
127126
vkldivergence(p, q, b::Real; dims=:) = (c = 1 / log(b); vmapreducethen(_klterm, +, x -> x * c, p, q, dims=dims))
128127
vtkldivergence(p, q; dims=:) = vtmapreduce(_klterm, +, p, q, dims=dims)
@@ -132,6 +131,30 @@ vtkldivergence(p, q, b::Real; dims=:) = (c = 1 / log(b); vtmapreducethen(_klterm
132131
vgkldiv(x, y; dims=:) = vvmapreduce((xᵢ, yᵢ) -> xᵢ * (log(xᵢ) - log(yᵢ)) - xᵢ + yᵢ, +, x, y, dims=dims)
133132
vtgkldiv(x, y; dims=:) = vtmapreduce((xᵢ, yᵢ) -> xᵢ * (log(xᵢ) - log(yᵢ)) - xᵢ + yᵢ, +, x, y, dims=dims)
134133

134+
# Rénya
135+
_vrenyadivergence(p, q, α::Real, dims) =
136+
vmapreducethen((pᵢ, qᵢ) -> pᵢ^α / qᵢ^-1), +, x -> (1/-1)) * log(x), p, q, dims=dims)
137+
138+
function vrenyadivergence(p, q, α::Real; dims=:)
139+
if α 0
140+
vmapreducethen((pᵢ, qᵢ) -> ifelse(pᵢ > zero(pᵢ), qᵢ, zero(qᵢ)), +, x -> -log(x), p, q, dims=dims)
141+
elseif α 0.5
142+
vmapreducethen((pᵢ, qᵢ) -> (pᵢ * qᵢ), +, x -> -2log(x), p, q, dims=dims)
143+
elseif α 1
144+
vkldivergence(p, q, dims=dims)
145+
elseif α 2
146+
c = log(_denom(p, dims))
147+
vmapreducethen(/, +, x -> log(x) - c, p, q, dims=dims)
148+
elseif isinf(α)
149+
vmapreducethen(/, max, log, p, q, dims=dims)
150+
else
151+
_vrenyadivergence(p, q, α, dims)
152+
end
153+
end
154+
155+
156+
157+
135158
################
136159
# Deviations
137160

@@ -174,3 +197,5 @@ const vmsd = vmse
174197
const vtmse = vtmsd
175198
const vrmsd = vrmse
176199
const vtrmsd = vtrmse
200+
201+
################

0 commit comments

Comments
 (0)