|
| 1 | +abstract type Distribution end |
| 2 | + |
| 3 | +sampler(::Type{D}) where {D<:Distribution} = error("sampler not implemented for $D") |
| 4 | +logpdf_fn(::Type{D}) where {D<:Distribution} = error("logpdf_fn not implemented for $D") |
| 5 | +params(d::Distribution) = error("params not implemented for $(typeof(d))") |
| 6 | + |
| 7 | +struct Normal{Tμ,Tσ,S<:Tuple} <: Distribution |
| 8 | + μ::Tμ |
| 9 | + σ::Tσ |
| 10 | + shape::S |
| 11 | + |
| 12 | + function Normal{Tμ,Tσ,S}(μ::Tμ, σ::Tσ, shape::S) where {Tμ,Tσ,S<:Tuple} |
| 13 | + isempty(shape) && throw(ArgumentError("shape cannot be empty")) |
| 14 | + return new{Tμ,Tσ,S}(μ, σ, shape) |
| 15 | + end |
| 16 | +end |
| 17 | + |
| 18 | +Normal(μ::Tμ, σ::Tσ, shape::S) where {Tμ,Tσ,S<:Tuple} = Normal{Tμ,Tσ,S}(μ, σ, shape) |
| 19 | +Normal() = Normal(0.0, 1.0, (1,)) |
| 20 | +Normal(μ, σ) = Normal(μ, σ, (1,)) |
| 21 | + |
| 22 | +params(d::Normal) = (d.μ, d.σ, d.shape) |
| 23 | + |
| 24 | +function _normal_sampler(rng, μ, σ, shape) |
| 25 | + return μ .+ σ .* randn(rng, shape) |
| 26 | +end |
| 27 | + |
| 28 | +function _normal_logpdf(x, μ, σ, _) |
| 29 | + z = (x .- μ) ./ σ |
| 30 | + n = length(x) |
| 31 | + return -n * log(σ) - n / 2 * log(2π) - sum(z .^ 2) / 2 |
| 32 | +end |
| 33 | + |
| 34 | +sampler(::Type{<:Normal}) = _normal_sampler |
| 35 | +logpdf_fn(::Type{<:Normal}) = _normal_logpdf |
0 commit comments