|
179 | 179 | function Distributions.Categorical(p::TrackedVector; check_args = true)
|
180 | 180 | return TuringDiscreteNonParametric(1:length(p), p, check_args = check_args)
|
181 | 181 | end
|
| 182 | + |
| 183 | +## Dirichlet ## |
| 184 | + |
| 185 | +struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution |
| 186 | + alpha::TV |
| 187 | + alpha0::T |
| 188 | + lmnB::T |
| 189 | +end |
| 190 | +function check(alpha) |
| 191 | + all(ai -> ai > 0, alpha) || |
| 192 | + throw(ArgumentError("Dirichlet: alpha must be a positive vector.")) |
| 193 | +end |
| 194 | +Zygote.@nograd DistributionsAD.check |
| 195 | + |
| 196 | +function TuringDirichlet(alpha::AbstractVector) |
| 197 | + check(alpha) |
| 198 | + alpha0 = sum(alpha) |
| 199 | + lmnB = sum(loggamma, alpha) - loggamma(alpha0) |
| 200 | + T = promote_type(typeof(alpha0), typeof(lmnB)) |
| 201 | + TV = typeof(alpha) |
| 202 | + TuringDirichlet{T, TV}(alpha, alpha0, lmnB) |
| 203 | +end |
| 204 | + |
| 205 | +function TuringDirichlet(d::Integer, alpha::Real) |
| 206 | + alpha0 = alpha * d |
| 207 | + _alpha = fill(alpha, d) |
| 208 | + lmnB = loggamma(alpha) * d - loggamma(alpha0) |
| 209 | + T = promote_type(typeof(alpha0), typeof(lmnB)) |
| 210 | + TV = typeof(_alpha) |
| 211 | + TuringDirichlet{T, TV}(_alpha, alpha0, lmnB) |
| 212 | +end |
| 213 | +function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer} |
| 214 | + Tf = float(T) |
| 215 | + TuringDirichlet(convert(AbstractVector{Tf}, alpha)) |
| 216 | +end |
| 217 | +TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha)) |
| 218 | + |
| 219 | +Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha) |
| 220 | +Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha) |
| 221 | + |
| 222 | +function Distributions.logpdf(d::TuringDirichlet, x::AbstractVector) |
| 223 | + simplex_logpdf(d.alpha, d.lmnB, x) |
| 224 | +end |
| 225 | +function Distributions.logpdf(d::TuringDirichlet, x::AbstractMatrix) |
| 226 | + simplex_logpdf(d.alpha, d.lmnB, x) |
| 227 | +end |
| 228 | +function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T} |
| 229 | + TV = typeof(d.alpha) |
| 230 | + logpdf(TuringDirichlet{T, TV}(d.alpha, d.alpha0, d.lmnB), x) |
| 231 | +end |
| 232 | + |
| 233 | +ZygoteRules.@adjoint function Distributions.Dirichlet(alpha) |
| 234 | + return pullback(TuringDirichlet, alpha) |
| 235 | +end |
| 236 | +ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha) |
| 237 | + return pullback(TuringDirichlet, d, alpha) |
| 238 | +end |
| 239 | + |
| 240 | +function simplex_logpdf(alpha, lmnB, x::AbstractVector) |
| 241 | + sum((alpha .- 1) .* log.(x)) - lmnB |
| 242 | +end |
| 243 | +function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) |
| 244 | + @views init = vcat(sum((alpha .- 1) .* log.(x[:,1]))) |
| 245 | + mapreduce(vcat, drop(eachcol(x), 1); init = init) do c |
| 246 | + sum((alpha .- 1) .* log.(c)) - lmnB |
| 247 | + end |
| 248 | +end |
| 249 | + |
| 250 | +Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector) |
| 251 | + simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin |
| 252 | + (Δ .* log.(data(x)), -Δ, Δ .* (data(alpha) .- 1)) |
| 253 | + end |
| 254 | +end |
| 255 | +Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) |
| 256 | + simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin |
| 257 | + (log.(data(x)) * Δ, -sum(Δ), repeat(data(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ)) |
| 258 | + end |
| 259 | +end |
| 260 | + |
| 261 | +ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector) |
| 262 | + simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1)) |
| 263 | +end |
| 264 | + |
| 265 | +ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) |
| 266 | + simplex_logpdf(alpha, lmnB, x), Δ -> begin |
| 267 | + (log.(x) * Δ, -sum(Δ), repeat(alpha .- 1, 1, size(x, 2)) * Diagonal(Δ)) |
| 268 | + end |
| 269 | +end |
| 270 | + |
182 | 271 | ## MvNormal ##
|
183 | 272 |
|
184 | 273 | """
|
|
249 | 338 | function _logpdf(d::TuringDiagMvNormal, x::AbstractMatrix)
|
250 | 339 | return -((size(x, 1) * log(2π) + 2 * sum(log.(d.σ))) .+ vec(sum(abs2.((x .- d.m) ./ d.σ), dims=1))) ./ 2
|
251 | 340 | end
|
| 341 | + |
252 | 342 | function _logpdf(d::TuringDenseMvNormal, x::AbstractVector)
|
253 | 343 | return -(length(x) * log(2π) + logdet(d.C) + sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)))) / 2
|
254 | 344 | end
|
|
0 commit comments