Skip to content

Commit e242314

Browse files
committed
Specialize logdensityof for most combinators
Improves type stability and numerical type propagation. Specialize logdensityof until a full refactor of the density engine.
1 parent 9703ab3 commit e242314

File tree

6 files changed

+79
-51
lines changed

6 files changed

+79
-51
lines changed

src/MeasureBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ function logdensity_def end
109109
using Compat
110110

111111
using IrrationalConstants
112+
using IrrationalConstants: loghalf
112113

113114
include("static.jl")
114115
include("smf.jl")

src/combinators/half.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, μ::Half) where {T}
1919
return abs(rand(rng, T, unhalf(μ)))
2020
end
2121

22+
function logdensityof::Half, x)
23+
ld = logdensityof(unhalf(μ), x) - loghalf
24+
return x 0 ? ld : oftype(ld, -Inf)
25+
end
26+
2227
logdensity_def::Half, x) = logdensity_def(unhalf(μ), x)
2328

2429
@inline function insupport(d::Half, x)

src/combinators/power.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,27 @@ params(d::PowerMeasure) = params(first(marginals(d)))
7878
basemeasure(d.parent)^d.axes
7979
end
8080

81-
@inline function logdensity_def(d::PowerMeasure{M}, x) where {M}
82-
parent = d.parent
83-
sum(x) do xj
84-
logdensity_def(parent, xj)
81+
for func in [:logdensityof, :logdensity_def]
82+
@eval @inline function $func(d::PowerMeasure{M}, x) where {M}
83+
parent = d.parent
84+
sum(x) do xj
85+
$func(parent, xj)
86+
end
8587
end
86-
end
8788

88-
@inline function logdensity_def(d::PowerMeasure{M,Tuple{Static.SOneTo{N}}}, x) where {M,N}
89-
parent = d.parent
90-
sum(1:N) do j
91-
@inbounds logdensity_def(parent, x[j])
89+
@eval @inline function $func(d::PowerMeasure{M,Tuple{Static.SOneTo{N}}}, x) where {M,N}
90+
parent = d.parent
91+
sum(1:N) do j
92+
@inbounds $func(parent, x[j])
93+
end
9294
end
93-
end
9495

95-
@inline function logdensity_def(
96-
d::PowerMeasure{M,NTuple{N,Static.SOneTo{0}}},
97-
x,
98-
) where {M,N}
99-
static(0.0)
96+
@eval @inline function $func(
97+
d::PowerMeasure{M,NTuple{N,Static.SOneTo{0}}},
98+
x,
99+
) where {M,N}
100+
static(0.0)
101+
end
100102
end
101103

102104
@inline function insupport::PowerMeasure, x)

src/combinators/product.jl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ function _rand_product(
7272
end |> collect
7373
end
7474

75-
@inline function logdensity_def(d::AbstractProductMeasure, x)
76-
mapreduce(logdensity_def, +, marginals(d), x)
75+
for func in [:logdensityof, :logdensity_def]
76+
@eval @inline function $func(d::AbstractProductMeasure, x)
77+
mapreduce($func, +, marginals(d), x)
78+
end
7779
end
7880

7981
struct ProductMeasure{M} <: AbstractProductMeasure
@@ -88,27 +90,37 @@ function Pretty.tile(d::ProductMeasure{T}) where {T<:Tuple}
8890
Pretty.list_layout(Pretty.tile.([marginals(d)...]), sep = "")
8991
end
9092

91-
# For tuples, `mapreduce` has trouble with type inference
92-
@inline function logdensity_def(d::ProductMeasure{T}, x) where {T<:Tuple}
93-
ℓs = map(logdensity_def, marginals(d), x)
94-
sum(ℓs)
95-
end
96-
97-
@generated function logdensity_def(d::ProductMeasure{NamedTuple{N,T}}, x) where {N,T}
93+
@eval @generated function _product_gen_impl(
94+
::Val{func},
95+
d::ProductMeasure{NamedTuple{N,T}},
96+
x,
97+
) where {func,N,T}
9898
k1 = QuoteNode(first(N))
9999
q = quote
100100
m = marginals(d)
101-
= logdensity_def(getproperty(m, $k1), getproperty(x, $k1))
101+
= $func(getproperty(m, $k1), getproperty(x, $k1))
102102
end
103103
for k in Base.tail(N)
104104
k = QuoteNode(k)
105-
qk = :(ℓ += logdensity_def(getproperty(m, $k), getproperty(x, $k)))
105+
qk = :(ℓ += $func(getproperty(m, $k), getproperty(x, $k)))
106106
push!(q.args, qk)
107107
end
108108

109109
return q
110110
end
111111

112+
for func in [:logdensityof, :logdensity_def]
113+
# For tuples, `mapreduce` has trouble with type inference
114+
@eval @inline function $func(d::ProductMeasure{T}, x) where {T<:Tuple}
115+
ℓs = map($func, marginals(d), x)
116+
sum(ℓs)
117+
end
118+
119+
@eval function $func(d::ProductMeasure{NamedTuple{N,T}}, x) where {N,T}
120+
_product_gen_impl(Val($func), d, x)
121+
end
122+
end
123+
112124
# @generated function basemeasure(d::ProductMeasure{NamedTuple{N,T}}, x) where {N,T}
113125
# q = quote
114126
# m = marginals(d)

src/combinators/spikemixture.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,17 @@ end
2121
SpikeMixture(basemeasure.m), static(1.0), static(1.0))
2222
end
2323

24-
@inline function logdensity_def::SpikeMixture, x)
25-
if iszero(x)
26-
return log.s)
27-
else
28-
return log.w) + logdensity_def.m, x)
24+
for func in [:logdensityof, :logdensity_def]
25+
@eval @inline function $func::SpikeMixture, x)
26+
R1 = Core.Compiler.return_type(log, Tuple{typeof.s)})
27+
R2 = Core.Compiler.return_type(log, Tuple{typeof.w)})
28+
R3 = Core.Compiler.return_type($func, Tuple{typeof.m),typeof(x)})
29+
R = promote_type(R1, R2, R3)
30+
if iszero(x)
31+
return convert(R, log.s))::R
32+
else
33+
return convert(R, log.w) + $func.m, x))::R
34+
end
2935
end
3036
end
3137

src/combinators/transformedmeasure.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -55,27 +55,29 @@ end
5555
# logdensity_rel(pushfwd(f, inv_f, ν.origin, WithVolCorr()), β.origin, x)
5656
# end
5757

58-
@inline function logdensity_def::PushforwardMeasure{F,I,M,<:WithVolCorr}, y) where {F,I,M}
59-
f = ν.f
60-
finv = ν.finv
61-
x_orig, inv_ladj = with_logabsdet_jacobian(unwrap(finv), y)
62-
logd_orig = logdensity_def.origin, x_orig)
63-
logd = float(logd_orig + inv_ladj)
64-
neginf = oftype(logd, -Inf)
65-
return ifelse(
66-
# Zero density wins against infinite volume:
67-
(isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) ||
68-
# Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
69-
# Return constant -Inf to prevent problems with ForwardDiff:
70-
(isfinite(logd_orig) && (inv_ladj == -Inf)),
71-
neginf,
72-
logd,
73-
)
74-
end
58+
for func in [:logdensityof, :logdensity_def]
59+
@eval @inline function $func::PushforwardMeasure{F,I,M,<:WithVolCorr}, y) where {F,I,M}
60+
f = ν.f
61+
finv = ν.finv
62+
x_orig, inv_ladj = with_logabsdet_jacobian(unwrap(finv), y)
63+
logd_orig = $func.origin, x_orig)
64+
logd = float(logd_orig + inv_ladj)
65+
neginf = oftype(logd, -Inf)
66+
return ifelse(
67+
# Zero density wins against infinite volume:
68+
(isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) ||
69+
# Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ?
70+
# Return constant -Inf to prevent problems with ForwardDiff:
71+
(isfinite(logd_orig) && (inv_ladj == -Inf)),
72+
neginf,
73+
logd,
74+
)
75+
end
7576

76-
@inline function logdensity_def::PushforwardMeasure{F,I,M,<:NoVolCorr}, y) where {F,I,M}
77-
x = ν.finv(y)
78-
return logdensity_def.origin, x)
77+
@eval @inline function $func::PushforwardMeasure{F,I,M,<:NoVolCorr}, y) where {F,I,M}
78+
x = ν.finv(y)
79+
return $func.origin, x)
80+
end
7981
end
8082

8183
insupport::PushforwardMeasure, y) = insupport.origin, ν.finv(y))

0 commit comments

Comments
 (0)