Skip to content

Commit dfb3701

Browse files
committed
distributions boilerplate
1 parent 1055f30 commit dfb3701

File tree

3 files changed

+45
-0
lines changed

3 files changed

+45
-0
lines changed

src/probprog/Distributions.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
abstract type Distribution end
2+
3+
sampler(::Type{D}) where {D<:Distribution} = error("sampler not implemented for $D")
4+
logpdf_fn(::Type{D}) where {D<:Distribution} = error("logpdf_fn not implemented for $D")
5+
params(d::Distribution) = error("params not implemented for $(typeof(d))")
6+
7+
struct Normal{Tμ,Tσ,S<:Tuple} <: Distribution
8+
μ::Tμ
9+
σ::Tσ
10+
shape::S
11+
12+
function Normal{Tμ,Tσ,S}(μ::Tμ, σ::Tσ, shape::S) where {Tμ,Tσ,S<:Tuple}
13+
isempty(shape) && throw(ArgumentError("shape cannot be empty"))
14+
return new{Tμ,Tσ,S}(μ, σ, shape)
15+
end
16+
end
17+
18+
Normal(μ::Tμ, σ::Tσ, shape::S) where {Tμ,Tσ,S<:Tuple} = Normal{Tμ,Tσ,S}(μ, σ, shape)
19+
Normal() = Normal(0.0, 1.0, (1,))
20+
Normal(μ, σ) = Normal(μ, σ, (1,))
21+
22+
params(d::Normal) = (d.μ, d.σ, d.shape)
23+
24+
function _normal_sampler(rng, μ, σ, shape)
25+
return μ .+ σ .* randn(rng, shape)
26+
end
27+
28+
function _normal_logpdf(x, μ, σ, _)
29+
z = (x .- μ) ./ σ
30+
n = length(x)
31+
return -n * log(σ) - n / 2 * log(2π) - sum(z .^ 2) / 2
32+
end
33+
34+
sampler(::Type{<:Normal}) = _normal_sampler
35+
logpdf_fn(::Type{<:Normal}) = _normal_logpdf

src/probprog/Modeling.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ function sample(
5656
return traced_result
5757
end
5858

59+
function sample(
60+
rng::AbstractRNG, dist::D; symbol::Symbol=gensym("sample")
61+
) where {D<:Distribution}
62+
return sample(rng, sampler(D), params(dist)...; symbol=symbol, logpdf=logpdf_fn(D))
63+
end
64+
5965
function untraced_call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs}
6066
args_with_rng = (rng, args...)
6167

src/probprog/ProbProg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using ..Reactant:
55
using ..Compiler: @jit, @compile
66

77
include("Types.jl")
8+
include("Distributions.jl")
89
include("FFI.jl")
910
include("Modeling.jl")
1011
include("Display.jl")
@@ -14,6 +15,9 @@ include("MCMC.jl")
1415
# Types.
1516
export ProbProgTrace, Constraint, Selection, Address
1617

18+
# Distributions.
19+
export Distribution, Normal
20+
1721
# Utility functions.
1822
export get_choices, select
1923

0 commit comments

Comments
 (0)