Skip to content

Commit 85fde05

Browse files
committed
Introduce strict_logdensityof
1 parent d65c1eb commit 85fde05

17 files changed

+80
-40
lines changed

ext/MeasureBaseChainRulesCoreExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,16 @@ ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checke
4444

4545
# = return type inference ====================================================
4646

47-
using MeasureBase: logdensityof_rt
47+
using MeasureBase: logdensityof_rt, strict_logdensityof_rt
4848

4949
_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
5050
function ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v)
5151
logdensityof_rt(target, v), _logdensityof_rt_pullback
5252
end
5353

54+
_strict_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
55+
function ChainRulesCore.rrule(::typeof(strict_logdensityof_rt), target, v)
56+
strict_logdensityof_rt(target, v), _strict_logdensityof_rt_pullback
57+
end
58+
5459
end # module MeasureBaseChainRulesCoreExt

src/combinators/half.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, μ::Half) where {T}
1919
return abs(rand(rng, T, unhalf(μ)))
2020
end
2121

22-
function logdensityof::Half, x)
23-
ld = logdensityof(unhalf(μ), x) - loghalf
22+
function strict_logdensityof::Half, x)
23+
ld = strict_logdensityof(unhalf(μ), x) - loghalf
2424
return x 0 ? ld : oftype(ld, -Inf)
2525
end
2626

src/combinators/power.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ params(d::PowerMeasure) = params(first(marginals(d)))
7878
basemeasure(d.parent)^d.axes
7979
end
8080

81-
for func in [:logdensityof, :logdensity_def]
81+
for func in [:strict_logdensityof, :logdensity_def]
8282
@eval @inline function $func(d::PowerMeasure{M}, x) where {M}
8383
parent = d.parent
8484
sum(x) do xj

src/combinators/product.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function _rand_product(
7272
end |> collect
7373
end
7474

75-
for func in [:logdensityof, :logdensity_def]
75+
for func in [:strict_logdensityof, :logdensity_def]
7676
@eval @inline function $func(d::AbstractProductMeasure, x)
7777
mapreduce($func, +, marginals(d), x)
7878
end
@@ -82,7 +82,7 @@ struct ProductMeasure{M} <: AbstractProductMeasure
8282
marginals::M
8383
end
8484

85-
@inline function logdensity_rel::ProductMeasure, ν::ProductMeasure, x)
85+
@inline function strict_logdensity_rel::ProductMeasure, ν::ProductMeasure, x)
8686
mapreduce(logdensity_rel, +, marginals(μ), marginals(ν), x)
8787
end
8888

@@ -109,7 +109,7 @@ end
109109
return q
110110
end
111111

112-
for func in [:logdensityof, :logdensity_def]
112+
for func in [:strict_logdensityof, :logdensity_def]
113113
# For tuples, `mapreduce` has trouble with type inference
114114
@eval @inline function $func(d::ProductMeasure{T}, x) where {T<:Tuple}
115115
ℓs = map($func, marginals(d), x)

src/combinators/spikemixture.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ end
2121
SpikeMixture(basemeasure.m), static(1.0), static(1.0))
2222
end
2323

24-
for func in [:logdensityof, :logdensity_def]
24+
for func in [:strict_logdensityof, :logdensity_def]
2525
@eval @inline function $func::SpikeMixture, x)
2626
# NOTE: We could instead write this as
2727
# R1 = typeof(log(one(μ.s)))

src/combinators/transformedmeasure.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,15 @@ function Pretty.tile(ν::PushforwardMeasure)
103103
end
104104

105105
# TODO: THIS IS ALMOST CERTAINLY WRONG
106-
# @inline function logdensity_rel(
106+
# @inline function strict_logdensity_rel(
107107
# ν::PushforwardMeasure{FF1,IF1,M1,<:AdaptRootMeasure},
108108
# β::PushforwardMeasure{FF2,IF2,M2,<:AdaptRootMeasure},
109109
# y,
110110
# ) where {FF1,IF1,M1,FF2,IF2,M2}
111111
# x = β.inv_f(y)
112112
# f = ν.inv_f ∘ β.f
113113
# inv_f = β.inv_f ∘ ν.f
114-
# logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x)
114+
# strict_logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x)
115115
# end
116116

117117
# TODO: Would profit from custom pullback:
@@ -132,7 +132,7 @@ function _combine_logd_with_ladj(logd_orig::Real, ladj::Real)
132132
end
133133
end
134134

135-
function logdensityof(
135+
function strict_logdensityof(
136136
@nospecialize::_NonBijectivePusfwdMeasure{M,<:PushfwdRootMeasure}),
137137
@nospecialize(v::Any)
138138
) where {M}
@@ -143,7 +143,7 @@ function logdensityof(
143143
)
144144
end
145145

146-
function logdensityof(
146+
function strict_logdensityof(
147147
@nospecialize::_NonBijectivePusfwdMeasure{M,<:AdaptRootMeasure}),
148148
@nospecialize(v::Any)
149149
) where {M}
@@ -154,7 +154,7 @@ function logdensityof(
154154
)
155155
end
156156

157-
for func in [:logdensityof, :logdensity_def]
157+
for func in [:strict_logdensityof, :logdensity_def]
158158
@eval function $func::PushforwardMeasure{F,I,M,<:AdaptRootMeasure}, y) where {F,I,M}
159159
f_inv = unwrap.finv)
160160
x, inv_ladj = with_logabsdet_jacobian(f_inv, y)

src/density-core.jl

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,15 @@ To compute log-density relative to `basemeasure(m)` or *define* a log-density
2626
`logdensity_def`.
2727
2828
To compute a log-density relative to a specific base-measure, see
29-
`logdensity_rel`.
29+
`logdensity_rel`.
30+
31+
# Implementation
32+
33+
Do not specialize `logdensityof` directly for subtypes of `AbstractMeasure`,
34+
specialize `MeasureBase.logdensity_def` and `MeasureBase.strict_logdensityof` instead.
3035
"""
31-
@inline function logdensityof::AbstractMeasure, x)
32-
result = dynamic(unsafe_logdensityof(μ, x))
33-
_checksupport(insupport(μ, x), result)
36+
@inline function logdensityof::AbstractMeasure, x) #!!!!!!!!!!!!!!!!!
37+
strict_logdensityof(μ, x)
3438
end
3539

3640
@inline function logdensityof_rt(::T, ::U) where {T,U}
@@ -41,6 +45,24 @@ _checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))
4145

4246
export unsafe_logdensityof
4347

48+
"""
49+
MeasureBase.strict_logdensityof(μ, x)
50+
51+
Compute the log-density of the measure `μ` at `x` relative to `rootmeasure(m)`.
52+
In contrast to [`logdensityof(μ, x)`](@ref), this will not take implicit pushforwards
53+
of `μ` (depending on the type of `x`) into account.
54+
"""
55+
function strict_logdensityof end
56+
57+
@inline function strict_logdensityof(μ, x)
58+
result = dynamic(unsafe_logdensityof(μ, x))
59+
_checksupport(insupport(μ, x), result)
60+
end
61+
62+
@inline function strict_logdensityof_rt(::T, ::U) where {T,U}
63+
Core.Compiler.return_type(strict_logdensityof, Tuple{T,U})
64+
end
65+
4466
# https://discourse.julialang.org/t/counting-iterations-to-a-type-fixpoint/75876/10?u=cscherrer
4567
"""
4668
unsafe_logdensityof(m, x)
@@ -68,14 +90,27 @@ See also `logdensityof`.
6890
end
6991

7092
"""
71-
logdensity_rel(m1, m2, x)
93+
logdensity_rel(μ, ν, x)
7294
73-
Compute the log-density of `m1` relative to `m2` at `x`. This function checks
74-
whether `x` is in the support of `m1` or `m2` (or both, or neither). If `x` is
95+
Compute the log-density of `μ` relative to `ν` at `x`. This function checks
96+
whether `x` is in the support of `μ` or `ν` (or both, or neither). If `x` is
7597
known to be in the support of both, it can be more efficient to call
76-
`unsafe_logdensity_rel`.
98+
`unsafe_logdensity_rel`.
99+
"""
100+
function logdensity_rel(μ, ν, x)
101+
strict_logdensity_rel(μ, ν, x)
102+
end
103+
77104
"""
78-
@inline function logdensity_rel::M, ν::N, x::X) where {M,N,X}
105+
MeasureBase.strict_logdensity_rel(μ, ν, x)
106+
107+
Compute the log-density of `μ` relative to `ν` at `x`. In contrast to
108+
[`logdensity_rel(μ, ν, x)`](@ref), this will not take implicit pushforwards
109+
of `μ` and `ν` (depending on the type of `x`) into account.
110+
"""
111+
function strict_logdensity_rel end
112+
113+
@inline function strict_logdensity_rel::M, ν::N, x::X) where {M,N,X}
79114
T = unstatic(
80115
promote_type(
81116
return_type(logdensity_def, (μ, x)),

src/density.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,10 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x)
163163

164164
density_def::DensityMeasure, x) = densityof.f, x)
165165

166-
function logdensityof::DensityMeasure, x::Any)
166+
function strict_logdensityof::DensityMeasure, x::Any)
167167
integrand, μ_base = μ.f, μ.base
168168

169-
base_logval = logdensityof(μ_base, x)
169+
base_logval = strict_logdensityof(μ_base, x)
170170

171171
T = typeof(base_logval)
172172
U = logdensityof_rt(integrand, x)

src/primitive.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ basemeasure(μ::PrimitiveMeasure) = μ
1919

2020
@inline basemeasure_depth(::PrimitiveMeasure) = static(0)
2121

22-
@inline logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x)))
23-
@inline logdensityof(::PrimitiveMeasure, x) = false
22+
@inline strict_logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x)))
23+
@inline strict_logdensityof(::PrimitiveMeasure, x) = false
2424

2525
logdensity_def(::PrimitiveMeasure, x) = static(0.0)
2626

src/primitives/counting.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ struct Counting{T} <: AbstractMeasure
1212
Counting(supp) = new{Core.Typeof(supp)}(supp)
1313
end
1414

15-
@inline function logdensityof::Counting, x::Real)
15+
@inline function strict_logdensityof::Counting, x::Real)
1616
R = float(typeof(x))
1717
insupport(μ, x) ? zero(R) : R(-Inf)
1818
end
1919

20-
@inline logdensityof::Counting, x) = insupport(μ, x) ? 0.0 : -Inf
20+
@inline strict_logdensityof::Counting, x) = insupport(μ, x) ? 0.0 : -Inf
2121

22-
@inline logdensity_def::Counting, x) = logdensityof(μ, x)
22+
@inline logdensity_def::Counting, x) = strict_logdensityof(μ, x)
2323

2424
basemeasure(::Counting) = CountingBase()
2525

0 commit comments

Comments
 (0)