Skip to content

Commit bc0a61e

Browse files
authored
density-core (#96)
* density-core * include("density-core.jl") * fix indentation
1 parent ded46a0 commit bc0a61e

File tree

4 files changed

+167
-164
lines changed

4 files changed

+167
-164
lines changed

src/MeasureBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ include("latent-joint.jl")
149149
include("rand.jl")
150150

151151
include("density.jl")
152+
include("density-core.jl")
152153

153154
include("interface.jl")
154155

src/density-core.jl

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
2+
"""
3+
logdensityof(m::AbstractMeasure, x)
4+
5+
Compute the log-density of the measure `m` at `x`. Density is always relative,
6+
but `DensityInterface.jl` does not account for this. For compatibility with
7+
this, `logdensityof` for a measure is always implicitly relative to
8+
[`rootmeasure(x)`](@ref rootmeasure).
9+
10+
`logdensityof` works by first computing `insupport(m, x)`. If this is true, then
11+
`unsafe_logdensityof` is called. If `insupport(m, x)` is known to be `true`, it
12+
can be a little faster to directly call `unsafe_logdensityof(m, x)`.
13+
14+
To compute log-density relative to `basemeasure(m)` or *define* a log-density
15+
(relative to `basemeasure(m)` or another measure given explicitly), see
16+
`logdensity_def`.
17+
18+
To compute a log-density relative to a specific base-measure, see
19+
`logdensity_rel`.
20+
"""
21+
@inline function logdensityof::AbstractMeasure, x)
22+
result = dynamic(unsafe_logdensityof(μ, x))
23+
ifelse(insupport(μ, x) == true, result, oftype(result, -Inf))
24+
end
25+
26+
export unsafe_logdensityof
27+
28+
# https://discourse.julialang.org/t/counting-iterations-to-a-type-fixpoint/75876/10?u=cscherrer
29+
"""
30+
unsafe_logdensityof(m, x)
31+
32+
Compute the log-density of the measure `m` at `x` relative to `rootmeasure(m)`.
33+
This is "unsafe" because it does not check `insupport(m, x)`.
34+
35+
See also `logdensityof`.
36+
"""
37+
@inline function unsafe_logdensityof::M, x) where {M}
38+
ℓ_0 = logdensity_def(μ, x)
39+
b_0 = μ
40+
Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number
41+
b_{i} = basemeasure(b_{i - 1}, x)
42+
if b_{i} isa typeof(b_{i - 1})
43+
return ℓ_{i - 1}
44+
end
45+
ℓ_{i} = let Δℓ_{i} = logdensity_def(b_{i}, x)
46+
ℓ_{i - 1} + Δℓ_{i}
47+
end
48+
end
49+
return ℓ_10
50+
end
51+
52+
export density_rel
53+
54+
@inline density_rel(μ, ν, x) = exp(logdensity_rel(μ, ν, x))
55+
56+
export logdensity_rel
57+
58+
"""
59+
logdensity_rel(m1, m2, x)
60+
61+
Compute the log-density of `m1` relative to `m2` at `x`. This function checks
62+
whether `x` is in the support of `m1` or `m2` (or both, or neither). If `x` is
63+
known to be in the support of both, it can be more efficient to call
64+
`unsafe_logdensity_rel`.
65+
"""
66+
@inline function logdensity_rel::M, ν::N, x::X) where {M,N,X}
67+
T = unstatic(
68+
promote_type(
69+
return_type(logdensity_def, (μ, x)),
70+
return_type(logdensity_def, (ν, x)),
71+
),
72+
)
73+
inμ = insupport(μ, x)
74+
inν = insupport(ν, x)
75+
inμ || return convert(T, ifelse(inν, -Inf, NaN))
76+
inν || return convert(T, Inf)
77+
78+
return unsafe_logdensity_rel(μ, ν, x)
79+
end
80+
81+
"""
82+
unsafe_logdensity_rel(m1, m2, x)
83+
84+
Compute the log-density of `m1` relative to `m2` at `x`, assuming `x` is
85+
known to be in the support of both `m1` and `m2`.
86+
87+
See also `logdensity_rel`.
88+
"""
89+
@inline function unsafe_logdensity_rel::M, ν::N, x::X) where {M,N,X}
90+
if static_hasmethod(logdensity_def, Tuple{M,N,X})
91+
return logdensity_def(μ, ν, x)
92+
end
93+
μs = basemeasure_sequence(μ)
94+
νs = basemeasure_sequence(ν)
95+
cb = commonbase(μs, νs, X)
96+
# _logdensity_rel(μ, ν)
97+
isnothing(cb) && begin
98+
μ = μs[end]
99+
ν = νs[end]
100+
@warn """
101+
No common base measure for
102+
103+
and
104+
105+
106+
Returning a relative log-density of NaN. If this is incorrect, add a
107+
three-argument method
108+
logdensity_def(, , x)
109+
"""
110+
return NaN
111+
end
112+
return _logdensity_rel(μs, νs, cb, x)
113+
end
114+
115+
# Note that this method assumes `μ` and `ν` to have the same type
116+
function logdensity_def::T, ν::T, x) where {T}
117+
if μ === ν
118+
return zero(logdensity_def(μ, x))
119+
else
120+
return logdensity_def(μ, x) - logdensity_def(ν, x)
121+
end
122+
end
123+
124+
@generated function _logdensity_rel(
125+
μs::Tμ,
126+
νs::Tν,
127+
::Tuple{StaticInt{M},StaticInt{N}},
128+
x::X,
129+
) where {Tμ,Tν,M,N,X}
130+
= schema(Tμ)
131+
= schema(Tν)
132+
133+
q = quote
134+
$(Expr(:meta, :inline))
135+
= logdensity_def(μs[$M], νs[$N], x)
136+
end
137+
138+
for i in 1:M-1
139+
push!(q.args, :(Δℓ = logdensity_def(μs[$i], x)))
140+
# push!(q.args, :(println("Adding", Δℓ)))
141+
push!(q.args, :(ℓ += Δℓ))
142+
end
143+
144+
for j in 1:N-1
145+
push!(q.args, :(Δℓ = logdensity_def(νs[$j], x)))
146+
# push!(q.args, :(println("Subtracting", Δℓ)))
147+
push!(q.args, :(ℓ -= Δℓ))
148+
end
149+
150+
push!(q.args, :(return ℓ))
151+
return q
152+
end
153+
154+
export densityof
155+
export logdensityof
156+
157+
export density_def
158+
159+
density_def(μ, ν::AbstractMeasure, x) = exp(logdensity_def(μ, ν, x))
160+
density_def(μ, x) = exp(logdensity_def(μ, x))

src/density.jl

Lines changed: 0 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -100,170 +100,6 @@ Define a new measure in terms of a log-density `f` over some measure `base`.
100100
"""
101101
∫exp(f::Function, μ) = (logfuncdensity(f), μ)
102102

103-
"""
104-
logdensityof(m::AbstractMeasure, x)
105-
106-
Compute the log-density of the measure `m` at `x`. Density is always relative,
107-
but `DensityInterface.jl` does not account for this. For compatibility with
108-
this, `logdensityof` for a measure is always implicitly relative to
109-
[`rootmeasure(x)`](@ref rootmeasure).
110-
111-
`logdensityof` works by first computing `insupport(m, x)`. If this is true, then
112-
`unsafe_logdensityof` is called. If `insupport(m, x)` is known to be `true`, it
113-
can be a little faster to directly call `unsafe_logdensityof(m, x)`.
114-
115-
To compute log-density relative to `basemeasure(m)` or *define* a log-density
116-
(relative to `basemeasure(m)` or another measure given explicitly), see
117-
`logdensity_def`.
118-
119-
To compute a log-density relative to a specific base-measure, see
120-
`logdensity_rel`.
121-
"""
122-
@inline function logdensityof::AbstractMeasure, x)
123-
result = dynamic(unsafe_logdensityof(μ, x))
124-
ifelse(insupport(μ, x) == true, result, oftype(result, -Inf))
125-
end
126-
127-
export unsafe_logdensityof
128-
129-
# https://discourse.julialang.org/t/counting-iterations-to-a-type-fixpoint/75876/10?u=cscherrer
130-
"""
131-
unsafe_logdensityof(m, x)
132-
133-
Compute the log-density of the measure `m` at `x` relative to `rootmeasure(m)`.
134-
This is "unsafe" because it does not check `insupport(m, x)`.
135-
136-
See also `logdensityof`.
137-
"""
138-
@inline function unsafe_logdensityof::M, x) where {M}
139-
ℓ_0 = logdensity_def(μ, x)
140-
b_0 = μ
141-
Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number
142-
b_{i} = basemeasure(b_{i - 1}, x)
143-
if b_{i} isa typeof(b_{i - 1})
144-
return ℓ_{i - 1}
145-
end
146-
ℓ_{i} = let Δℓ_{i} = logdensity_def(b_{i}, x)
147-
ℓ_{i - 1} + Δℓ_{i}
148-
end
149-
end
150-
return ℓ_10
151-
end
152-
153-
export density_rel
154-
155-
@inline density_rel(μ, ν, x) = exp(logdensity_rel(μ, ν, x))
156-
157-
export logdensity_rel
158-
159-
@inline return_type(f, args::Tuple) = Core.Compiler.return_type(f, Tuple{typeof.(args)...})
160-
161-
unstatic(::Type{T}) where {T} = T
162-
unstatic(::Type{StaticFloat64{X}}) where {X} = Float64
163-
164-
"""
165-
logdensity_rel(m1, m2, x)
166-
167-
Compute the log-density of `m1` relative to `m2` at `x`. This function checks
168-
whether `x` is in the support of `m1` or `m2` (or both, or neither). If `x` is
169-
known to be in the support of both, it can be more efficient to call
170-
`unsafe_logdensity_rel`.
171-
"""
172-
@inline function logdensity_rel::M, ν::N, x::X) where {M,N,X}
173-
T = unstatic(
174-
promote_type(
175-
return_type(logdensity_def, (μ, x)),
176-
return_type(logdensity_def, (ν, x)),
177-
),
178-
)
179-
inμ = insupport(μ, x)
180-
inν = insupport(ν, x)
181-
inμ || return convert(T, ifelse(inν, -Inf, NaN))
182-
inν || return convert(T, Inf)
183-
184-
return unsafe_logdensity_rel(μ, ν, x)
185-
end
186-
187-
"""
188-
unsafe_logdensity_rel(m1, m2, x)
189-
190-
Compute the log-density of `m1` relative to `m2` at `x`, assuming `x` is
191-
known to be in the support of both `m1` and `m2`.
192-
193-
See also `logdensity_rel`.
194-
"""
195-
@inline function unsafe_logdensity_rel::M, ν::N, x::X) where {M,N,X}
196-
if static_hasmethod(logdensity_def, Tuple{M,N,X})
197-
return logdensity_def(μ, ν, x)
198-
end
199-
μs = basemeasure_sequence(μ)
200-
νs = basemeasure_sequence(ν)
201-
cb = commonbase(μs, νs, X)
202-
# _logdensity_rel(μ, ν)
203-
isnothing(cb) && begin
204-
μ = μs[end]
205-
ν = νs[end]
206-
@warn """
207-
No common base measure for
208-
209-
and
210-
211-
212-
Returning a relative log-density of NaN. If this is incorrect, add a
213-
three-argument method
214-
logdensity_def(, , x)
215-
"""
216-
return NaN
217-
end
218-
return _logdensity_rel(μs, νs, cb, x)
219-
end
220-
221-
# Note that this method assumes `μ` and `ν` to have the same type
222-
function logdensity_def::T, ν::T, x) where {T}
223-
if μ === ν
224-
return zero(logdensity_def(μ, x))
225-
else
226-
return logdensity_def(μ, x) - logdensity_def(ν, x)
227-
end
228-
end
229-
230-
@generated function _logdensity_rel(
231-
μs::Tμ,
232-
νs::Tν,
233-
::Tuple{StaticInt{M},StaticInt{N}},
234-
x::X,
235-
) where {Tμ,Tν,M,N,X}
236-
= schema(Tμ)
237-
= schema(Tν)
238-
239-
q = quote
240-
$(Expr(:meta, :inline))
241-
= logdensity_def(μs[$M], νs[$N], x)
242-
end
243-
244-
for i in 1:M-1
245-
push!(q.args, :(Δℓ = logdensity_def(μs[$i], x)))
246-
# push!(q.args, :(println("Adding", Δℓ)))
247-
push!(q.args, :(ℓ += Δℓ))
248-
end
249-
250-
for j in 1:N-1
251-
push!(q.args, :(Δℓ = logdensity_def(νs[$j], x)))
252-
# push!(q.args, :(println("Subtracting", Δℓ)))
253-
push!(q.args, :(ℓ -= Δℓ))
254-
end
255-
256-
push!(q.args, :(return ℓ))
257-
return q
258-
end
259-
260-
export densityof
261-
export logdensityof
262-
263-
export density_def
264-
265-
density_def(μ, ν::AbstractMeasure, x) = exp(logdensity_def(μ, ν, x))
266-
density_def(μ, x) = exp(logdensity_def(μ, x))
267103

268104
"""
269105
rebase(μ, ν)

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,9 @@ end
148148
function rmap(f, nt::NamedTuple{N,T}) where {N,T}
149149
NamedTuple{N}(map(x -> rmap(f, x), values(nt)))
150150
end
151+
152+
153+
@inline return_type(f, args::Tuple) = Core.Compiler.return_type(f, Tuple{typeof.(args)...})
154+
155+
unstatic(::Type{T}) where {T} = T
156+
unstatic(::Type{StaticFloat64{X}}) where {X} = Float64

0 commit comments

Comments
 (0)