Skip to content

Commit 46b676e

Browse files
authored
Dev (#15)
* update 3-arg logdensity * update logjac * factoredbase * update power measure * update half.jl to use Factoredbase * drop outdated test * add a test * bump version
1 parent 1ba2986 commit 46b676e

File tree

8 files changed

+70
-20
lines changed

8 files changed

+70
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.3.7"
4+
version = "0.4.0"
55

66
[deps]
77
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"

src/MeasureBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ include("primitives/lebesgue.jl")
4949
include("primitives/dirac.jl")
5050
include("primitives/trivial.jl")
5151

52+
include("combinators/factoredbase.jl")
5253
include("combinators/weighted.jl")
5354
include("combinators/affine.jl")
5455
include("combinators/superpose.jl")

src/combinators/affine.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ Base.propertynames(d::AffineTransform{N}) where {N} = N
2222
(f::AffineTransform{(:μ,:σ)})(x) = f.σ * x + f.μ
2323
(f::AffineTransform{(:μ,:ω)})(x) = f.ω \ x + f.μ
2424

25+
# TODO: `log` doesn't work for the multivariate case, we need the log absolute determinant
26+
logjac(f::AffineTransform{(:μ,:σ)}) = log(f.σ)
27+
logjac(f::AffineTransform{(:μ,:ω)}) = -log(f.ω)
28+
logjac(f::AffineTransform{(:σ,)}) = log(f.σ)
29+
logjac(f::AffineTransform{(:ω,)}) = -log(f.ω)
30+
logjac(f::AffineTransform{(:μ,)}) = 0.0
31+
2532
###############################################################################
2633

2734
struct Affine{N,M,T} <: AbstractMeasure
@@ -61,8 +68,6 @@ Base.propertynames(d::Affine{N}) where {N} = N ∪ (:parent,)
6168
end
6269
end
6370

64-
65-
6671
# Note: We could also write
6772
# logdensity(d::Affine, x) = logdensity(inv(getfield(d, :f)), x)
6873

@@ -74,16 +79,12 @@ logdensity(d::Affine{(:μ,)}, x) = logdensity(d.parent, x - d.μ)
7479

7580
basemeasure(d::Affine) = Affine(getfield(d, :f), basemeasure(d.parent))
7681

77-
basemeasure(d::Affine{N,L}) where {N, L<:Lebesgue} = d.parent
82+
# We can't do this until we know we're working with Lebesgue measure, since for
83+
# example it wouldn't make sense to apply a log-Jacobian to a point measure
84+
basemeasure(d::Affine{N,L}) where {N, L<:Lebesgue} = WeightedMeasure(-logjac(d), d.parent)
7885

79-
logdensity(d::Affine{N,L}, x) where {N,L<:Lebesgue} = logjac(getfield(d, :f))
86+
logjac(d::Affine) = logjac(getfield(d, :f))
8087

81-
# TODO: `log` doesn't work for the multivariate case, we need the log absolute determinant
82-
logjac(::AffineTransform{(:μ,:σ)}) = -log(d.σ)
83-
logjac(::AffineTransform{(:μ,:ω)}) = log(d.ω)
84-
logjac(::AffineTransform{(:σ,)}) = -log(d.σ)
85-
logjac(::AffineTransform{(:ω,)}) = log(d.ω)
86-
logjac(::AffineTransform{(:μ,)}) = 0.0
8788

8889
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, d::Affine) where {T}
8990
z = rand(rng, T, parent(d))

src/combinators/factoredbase.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
export FactoredBase
2+
3+
struct FactoredBase{R,C,V,B} <: AbstractMeasure
4+
inbounds :: R
5+
const:: C
6+
varℓ :: V
7+
base :: B
8+
end
9+
10+
function logdensity(d::FactoredBase, x)
11+
d.inbounds(x) || return -Inf
12+
d.const+ d.varℓ
13+
end
14+
15+
basemeasure(d::FactoredBase) = d.base

src/combinators/half.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,16 @@ end
1111

1212
unhalf::Half) = μ.parent
1313

14-
basemeasure::Half) = WeightedMeasure(logtwo, basemeasure(unhalf(μ)))
14+
function basemeasure::Half)
15+
inbounds(x) = x > 0
16+
const= logtwo
17+
varℓ = 0.0
18+
base = basemeasure(unhalf(μ))
19+
FactoredBase(inbounds, constℓ, varℓ, base)
20+
end
1521

1622
function Base.rand(rng::AbstractRNG, T::Type, μ::Half)
1723
return abs(rand(rng, T, unhalf(μ)))
1824
end
1925

20-
logdensity::Half, x) = x > 0 ? logdensity(unhalf(μ), x) : -Inf
26+
logdensity::Half, x) = logdensity(unhalf(μ), x)

src/combinators/power.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,21 @@ params(d::ProductMeasure{F,<:Fill}) where {F} = params(first(marginals(d)))
6363
params(::Type{P}) where {F,P<:ProductMeasure{F,<:Fill}} = params(D)
6464

6565
# basemeasure(μ::PowerMeasure) = @inbounds basemeasure(first(μ.data))^size(μ.data)
66+
67+
@inline basemeasure(d::PowerMeasure) = _basemeasure(d, (basemeasure(d.f(first(d.pars)))))
68+
69+
@inline _basemeasure(d::PowerMeasure, b) = b ^ size(d.pars)
70+
71+
@inline function _basemeasure(d::PowerMeasure, b::FactoredBase)
72+
n = length(d.pars)
73+
inbounds(x) = all(xj -> b.inbounds(xj), x)
74+
const= n * b.const
75+
varℓ = n * b.varℓ
76+
base = b.base ^ size(d.pars)
77+
FactoredBase(inbounds, constℓ, varℓ, base)
78+
end
79+
80+
function logdensity(d::PowerMeasure, x)
81+
d1 = d.f(first(d.pars))
82+
sum(xj -> logdensity(d1, xj), x)
83+
end

src/density.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ function logdensity(μ::AbstractMeasure, ν::AbstractMeasure, x)
107107
α = basemeasure(μ)
108108
β = basemeasure(ν)
109109

110+
# If α===μ and β===ν, The recursive call would be exactly the same as the
111+
# original one. We need to break the recursion.
110112
if α===μ && β===ν
111113
@warn """
112114
No method found for logdensity(μ, ν, x) where
@@ -119,10 +121,18 @@ function logdensity(μ::AbstractMeasure, ν::AbstractMeasure, x)
119121
return NaN
120122
end
121123

122-
ℓμ = logdensity(μ, x)
123-
ℓν = logdensity(ν, x)
124+
# Infinite or NaN results occur when outside the support of α or β,
125+
# and also when one measure is singular wrt the other. Computing the base
126+
# measures first is often much cheaper, and allows the numerically-intensive
127+
# computation to "fall through" in these cases.
128+
# TODO: Add tests to check that NaN cases work properly
129+
= logdensity(α, β, x)
130+
isfinite(ℓ) || return
124131

125-
return ℓμ - ℓν + logdensity(α, β, x)
132+
+= logdensity(μ, x)
133+
-= logdensity(ν, x)
134+
135+
return
126136
end
127137

128138
logdensity(::Lebesgue, ::Lebesgue, x) = 0.0

test/runtests.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,12 @@ end
154154
@testset "Half" begin
155155
Normal() = ∫exp(x -> -0.5x^2, Lebesgue(ℝ))
156156
@half Normal
157-
@test logdensity(HalfNormal(), -0.2) == -Inf
157+
@test logdensity(HalfNormal(), Lebesgue(ℝ), -0.2) == -Inf
158158
@test logdensity(HalfNormal(), 0.2) == logdensity(Normal(), 0.2)
159-
160-
@half Lebesgue
161-
@test basemeasure(HalfLebesgue(ℝ)) == 2 * Lebesgue(ℝ)
159+
@test density(HalfNormal(), Lebesgue(ℝ), 0.2) 2 * density(Normal(), Lebesgue(ℝ), 0.2)
162160
end
163161

162+
164163
# import MeasureBase.:⋅
165164
# function ⋅(μ::Normal, kernel)
166165
# m = kernel(μ)

0 commit comments

Comments
 (0)