Skip to content

Commit 9add372

Browse files
committed
Specialize logdensityof for DensityMeasure
Ensures proper type propagation (until future refactor of density calculation engine).
1 parent 1aa6984 commit 9add372

File tree

5 files changed

+63
-7
lines changed

5 files changed

+63
-7
lines changed

ext/MeasureBaseChainRulesCoreExt.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@ using ChainRulesCore: NoTangent, ZeroTangent
77
import ChainRulesCore
88

99

10+
# = utils ====================================================================
11+
12+
using MeasureBase: isneginf, isposinf
13+
14+
_isneginf_pullback(::Any) = (NoTangent(), ZeroTangent())
15+
ChainRulesCore.rrule(::typeof(isneginf), x) = isneginf(x), _logdensityof_rt_pullback
16+
17+
_isposinf_pullback(::Any) = (NoTangent(), ZeroTangent())
18+
ChainRulesCore.rrule(::typeof(isposinf), x) = isposinf(x), _isposinf_pullback
19+
20+
1021
# = insupport & friends ======================================================
1122

1223
using MeasureBase:
@@ -44,4 +55,12 @@ _check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
4455
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback
4556

4657

58+
# = return type inference ====================================================
59+
60+
using MeasureBase: logdensityof_rt
61+
62+
_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
63+
ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v) = logdensityof_rt(target, v), _logdensityof_rt_pullback
64+
65+
4766
end # module MeasureBaseChainRulesCoreExt

src/density-core.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ To compute a log-density relative to a specific base-measure, see
3333
_checksupport(insupport(μ, x), result)
3434
end
3535

36+
@inline function logdensityof_rt(::T, ::U) where {T,U}
37+
Core.Compiler.return_type(logdensityof, Tuple{T,U})
38+
end
39+
3640
_checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))
3741

3842

src/density.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,25 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x)
163163

164164
density_def::DensityMeasure, x) = densityof.f, x)
165165

166+
function logdensityof::DensityMeasure, x::Any)
167+
integrand, μ_base = μ.f, μ.base
168+
169+
base_logval = logdensityof(μ_base, x)
170+
171+
T = typeof(base_logval)
172+
U = logdensityof_rt(integrand, x)
173+
R = promote_type(T, U)
174+
175+
# Don't evaluate base measure if integrand is zero or NaN
176+
if isneginf(base_logval)
177+
R(-Inf)
178+
else
179+
integrand_logval = logdensityof(integrand, x)
180+
convert(R, integrand_logval + base_logval)::R
181+
end
182+
end
183+
184+
166185
"""
167186
rebase(μ, ν)
168187

src/utils.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,20 +165,18 @@ using InverseFunctions: FunctionWithInverse
165165
unwrap(f) = f
166166
unwrap(f::FunctionWithInverse) = f.f
167167

168-
169168
fcomp(f, g) = fchain(g, f)
170169
fcomp(::typeof(identity), g) = g
171170
fcomp(f, ::typeof(identity)) = f
172171
fcomp(::typeof(identity), ::typeof(identity)) = identity
173172

173+
near_neg_inf(::Type{T}) where {T<:Real} = T(-1E38) # Still fits into Float32
174174

175-
near_neg_inf(::Type{T}) where T<:Real = T(-1E38) # Still fits into Float32
176-
177-
isneginf(x) = isinf(x) && x < 0
178-
isposinf(x) = isinf(x) && x > 0
175+
isneginf(x) = isinf(x) && x < zero(x)
176+
isposinf(x) = isinf(x) && x > zero(x)
179177

180-
isapproxzero(x::T) where T<:Real = x zero(T)
178+
isapproxzero(x::T) where {T<:Real} = x zero(T)
181179
isapproxzero(A::AbstractArray) = all(isapproxzero, A)
182180

183-
isapproxone(x::T) where T<:Real = x one(T)
181+
isapproxone(x::T) where {T<:Real} = x one(T)
184182
isapproxone(A::AbstractArray) = all(isapproxone, A)

test/test_basics.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,22 @@ end
189189
end
190190
end
191191

192+
@testset "logdensityof" begin
193+
f1 = let A=randn(Float32, 3,3); x -> sum(A*x); end
194+
f2 = x -> sqrt(abs(sum(x)))
195+
f3 = x -> 2 * sum(x)
196+
f4 = x -> sum(sqrt.(abs.(x)))
197+
m = @inferred ∫exp(f1, ∫exp(f2, ∫exp(f3, ∫exp(f4, StdUniform()^3))))
198+
199+
for x in [
200+
Float32[0.7, 0.2, 0.5],
201+
Float32[-0.7, 0.2, 0.5],
202+
]
203+
@test @inferred(logdensityof(m, x)) isa Float32
204+
@test logdensityof(m, x) f1(x) + f2(x) + f3(x) + f4(x) + logdensityof(StdUniform()^3, x)
205+
end
206+
end
207+
192208
@testset "logdensity_rel" begin
193209
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 0.0) == Inf
194210
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 1.0) == -Inf

0 commit comments

Comments
 (0)