Skip to content

Commit 0890e24

Browse files
authored
Fix exponential (#164)
* Make _logpdf use Real (for autodiff friendliness) * exponential updates * bump version
1 parent 8ddd2a5 commit 0890e24

File tree

4 files changed

+83
-7
lines changed

4 files changed

+83
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureTheory"
22
uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.11.4"
4+
version = "0.12.0"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/MeasureTheory.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ function logpdf(d::AbstractMeasure, x)
7676
_logpdf(d, x, logdensity(d,x))
7777
end
7878

79-
function _logpdf(d::AbstractMeasure, x, acc::Float64)
79+
function _logpdf(d::AbstractMeasure, x, acc::Real)
8080
β = basemeasure(d)
8181
d === β && return acc
8282

@@ -104,6 +104,7 @@ include("combinators/weighted.jl")
104104
include("combinators/product.jl")
105105
include("combinators/transforms.jl")
106106
include("combinators/chain.jl")
107+
# include("combinators/basemeasure.jl")
107108

108109
include("distributions.jl")
109110

src/parameterized/exponential.jl

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
export Exponential
55

6-
@parameterized Exponential(λ) Lebesgue(ℝ₊)
6+
@parameterized Exponential(β) Lebesgue(ℝ₊)
7+
78

8-
@kwstruct Exponential(λ)
99
@kwstruct Exponential()
1010

1111
function logdensity(d::Exponential{()} , x)
@@ -14,21 +14,82 @@ end
1414

1515
Base.rand(rng::AbstractRNG, T::Type, μ::Exponential{()}) = randexp(rng,T)
1616

17+
TV.as(::Exponential) = asℝ₊
18+
19+
20+
##########################
21+
# Scale β
22+
23+
@kwstruct Exponential(β)
24+
25+
function Base.rand(rng::AbstractRNG, T::Type, d::Exponential{(:β,)})
26+
randexp(rng, T) * d.β
27+
end
28+
29+
function logdensity(d::Exponential{(:β,)}, x)
30+
z = x / d.β
31+
return logdensity(Exponential(), z) - log(d.β)
32+
end
33+
34+
distproxy(d::Exponential{(:β,)}) = Dists.Exponential(d.β)
35+
36+
asparams(::Type{<:Exponential}, ::Val{:β}) = asℝ₊
37+
38+
##########################
39+
# Log-Scale logβ
40+
41+
42+
@kwstruct Exponential(logβ)
43+
44+
function Base.rand(rng::AbstractRNG, T::Type, d::Exponential{(:logβ,)})
45+
randexp(rng, T) * exp(d.logβ)
46+
end
47+
48+
function logdensity(d::Exponential{(:logβ,)}, x)
49+
z = x * exp(-d.logβ)
50+
return logdensity(Exponential(), z) - d.logβ
51+
end
52+
53+
distproxy(d::Exponential{(:logβ,)}) = Dists.Exponential(exp(d.logβ))
54+
55+
asparams(::Type{<:Exponential}, ::Val{:logβ}) = asℝ
56+
57+
58+
1759

1860
##########################
61+
# Rate λ
62+
63+
@kwstruct Exponential(λ)
1964

2065
function Base.rand(rng::AbstractRNG, T::Type, d::Exponential{(:λ,)})
2166
randexp(rng, T) / d.λ
2267
end
2368

24-
TV.as(::Exponential) = asℝ₊
25-
2669
function logdensity(d::Exponential{(:λ,)}, x)
2770
z = x * d.λ
2871
return logdensity(Exponential(), z) + log(d.λ)
2972
end
3073

31-
distproxy(d::Exponential{(:λ,)}) = Dists.Exponential(d.λ)
74+
distproxy(d::Exponential{(:λ,)}) = Dists.Exponential(1/d.λ)
3275

3376
asparams(::Type{<:Exponential}, ::Val{:λ}) = asℝ₊
77+
78+
##########################
79+
# Log-Rate logλ
80+
81+
82+
@kwstruct Exponential(logλ)
83+
84+
function Base.rand(rng::AbstractRNG, T::Type, d::Exponential{(:logλ,)})
85+
randexp(rng, T) * exp(-d.logλ)
86+
end
87+
88+
function logdensity(d::Exponential{(:logλ,)}, x)
89+
z = x * exp(d.logλ)
90+
return logdensity(Exponential(), z) + d.logλ
91+
end
92+
93+
distproxy(d::Exponential{(:logλ,)}) = Dists.Exponential(exp(-d.logλ))
94+
3495
asparams(::Type{<:Exponential}, ::Val{:logλ}) = asℝ

test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,20 @@ end
9494
@test_broken logdensity(Binomial(n,p), CountingMeasure(ℤ[0:n]), x) binomlogpdf(n,p,x)
9595
end
9696

97+
@testset "Exponential" begin
98+
r = rand(MersenneTwister(123), Exponential(2))
99+
@test r rand(MersenneTwister(123), Exponential=2))
100+
@test r rand(MersenneTwister(123), Exponential=0.5))
101+
@test r rand(MersenneTwister(123), Exponential(logβ=log(2)))
102+
@test r rand(MersenneTwister(123), Exponential(logλ=log(0.5)))
103+
104+
= logdensity(Exponential(2), r)
105+
@test logdensity(Exponential=2), r)
106+
@test logdensity(Exponential=0.5), r)
107+
@test logdensity(Exponential(logβ=log(2)), r)
108+
@test logdensity(Exponential(logλ=log(0.5)), r)
109+
end
110+
97111
@testset "NegativeBinomial" begin
98112
D = NegativeBinomial{(:r, :p)}
99113
par = transform(asparams(D), randn(2))

0 commit comments

Comments
 (0)