@@ -118,11 +118,10 @@ end
118
118
# ###############
119
119
# Divergences
120
120
121
+ # Kullback-Leibler
121
122
# # StatsBase handling of pᵢ = qᵢ = 0
122
- # _xlogxdy(x::T, y::T) where {T} = _xlogy(x, ifelse(iszero(x) & iszero(y), zero(T), x / y))
123
123
# vkldivergence(p, q; dims=:) = vvmapreduce(_xlogxdy, +, p, q, dims=dims)
124
124
# Slightly more efficient (and likely more stable)
125
- _klterm (x:: T , y:: T ) where {T} = _xlogy (x, x) - _xlogy (x, y)
126
125
vkldivergence (p, q; dims= :) = vvmapreduce (_klterm, + , p, q, dims= dims)
127
126
vkldivergence (p, q, b:: Real ; dims= :) = (c = 1 / log (b); vmapreducethen (_klterm, + , x -> x * c, p, q, dims= dims))
128
127
vtkldivergence (p, q; dims= :) = vtmapreduce (_klterm, + , p, q, dims= dims)
@@ -132,6 +131,30 @@ vtkldivergence(p, q, b::Real; dims=:) = (c = 1 / log(b); vtmapreducethen(_klterm
132
131
vgkldiv (x, y; dims= :) = vvmapreduce ((xᵢ, yᵢ) -> xᵢ * (log (xᵢ) - log (yᵢ)) - xᵢ + yᵢ, + , x, y, dims= dims)
133
132
vtgkldiv (x, y; dims= :) = vtmapreduce ((xᵢ, yᵢ) -> xᵢ * (log (xᵢ) - log (yᵢ)) - xᵢ + yᵢ, + , x, y, dims= dims)
134
133
134
+ # Rénya
135
+ _vrenyadivergence (p, q, α:: Real , dims) =
136
+ vmapreducethen ((pᵢ, qᵢ) -> pᵢ^ α / qᵢ^ (α- 1 ), + , x -> (1 / (α- 1 )) * log (x), p, q, dims= dims)
137
+
138
+ function vrenyadivergence (p, q, α:: Real ; dims= :)
139
+ if α ≈ 0
140
+ vmapreducethen ((pᵢ, qᵢ) -> ifelse (pᵢ > zero (pᵢ), qᵢ, zero (qᵢ)), + , x -> - log (x), p, q, dims= dims)
141
+ elseif α ≈ 0.5
142
+ vmapreducethen ((pᵢ, qᵢ) -> √ (pᵢ * qᵢ), + , x -> - 2 log (x), p, q, dims= dims)
143
+ elseif α ≈ 1
144
+ vkldivergence (p, q, dims= dims)
145
+ elseif α ≈ 2
146
+ c = log (_denom (p, dims))
147
+ vmapreducethen (/ , + , x -> log (x) - c, p, q, dims= dims)
148
+ elseif isinf (α)
149
+ vmapreducethen (/ , max, log, p, q, dims= dims)
150
+ else
151
+ _vrenyadivergence (p, q, α, dims)
152
+ end
153
+ end
154
+
155
+
156
+
157
+
135
158
# ###############
136
159
# Deviations
137
160
@@ -174,3 +197,5 @@ const vmsd = vmse
174
197
const vtmse = vtmsd
175
198
const vrmsd = vrmse
176
199
const vtrmsd = vtrmse
200
+
201
+ # ###############
0 commit comments