Skip to content

Commit 352296f

Browse files
authored
Avoid unneeded array traversals (#108)
## `insupport` for product measures `insupport` currently has this method: ```julia @inline function insupport(d::AbstractProductMeasure, x::AbstractArray) mar = marginals(d) for (j, mj) in enumerate(mar) dynamic(insupport(mj, x[j])) || return false end return true end ``` But we can often infer that we're in the support, so we can save the cost of traversing the array. This PR changes this to ```julia @inline function insupport(d::AbstractProductMeasure, x::AbstractArray) mar = marginals(d) # We might get lucky and know statically that everything is inbounds T = Core.Compiler.return_type(insupport, Tuple{eltype(mar),eltype(x)}) T <: True || all(zip(x, mar)) do (xj, mj) insupport(mj, xj) == true end end ``` ## `logdensity_def` for powers of primitive measures Again, we're currently traversing the array unnecessarily. This PR adds these methods ```julia logdensity_def(::PowerMeasure{P}, x) where {P<:PrimitiveMeasure} = static(0.0) # To avoid ambiguities function logdensity_def( ::PowerMeasure{P,Tuple{Vararg{Base.OneTo{Static.StaticInt{0}},N}}}, x, ) where {P<:PrimitiveMeasure,N} static(0.0) end ``` ## `istrue` Lots of code expecting a Bool doesn't work for StaticBools. In these cases, we instead need to check e.g. `p == true`. I've added `istrue(p) = p == true` to make this a little simpler
1 parent fe894f7 commit 352296f

File tree

6 files changed

+29
-15
lines changed

6 files changed

+29
-15
lines changed

src/MeasureBase.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ using LogExpFunctions: logsumexp, logistic, logit
9090

9191
@deprecate instance_type(x) Core.Typeof(x) false
9292

93+
# Mostly useful for StaticBools
94+
istrue(p) = p == true
95+
9396
"""
9497
`logdensity_def` is the standard way to define a log-density for a new measure.
9598
Note that this definition does not include checking for membership in the

src/combinators/power.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,13 @@ function checked_arg(μ::PowerMeasure, x::Any)
130130
end
131131

132132
massof(m::PowerMeasure) = massof(m.parent)^prod(m.axes)
133+
134+
logdensity_def(::PowerMeasure{P}, x) where {P<:PrimitiveMeasure} = static(0.0)
135+
136+
# To avoid ambiguities
137+
function logdensity_def(
138+
::PowerMeasure{P,Tuple{Vararg{Base.OneTo{Static.StaticInt{0}},N}}},
139+
x,
140+
) where {P<:PrimitiveMeasure,N}
141+
static(0.0)
142+
end

src/combinators/product.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,10 +212,11 @@ end
212212

213213
@inline function insupport(d::AbstractProductMeasure, x::AbstractArray)
214214
mar = marginals(d)
215-
for (j, mj) in enumerate(mar)
216-
dynamic(insupport(mj, x[j])) || return false
215+
# We might get lucky and know statically that everything is inbounds
216+
T = Core.Compiler.return_type(insupport, Tuple{eltype(mar),eltype(x)})
217+
T <: True || all(zip(x, mar)) do (xj, mj)
218+
insupport(mj, xj) == true
217219
end
218-
return true
219220
end
220221

221222
@inline function insupport(d::AbstractProductMeasure, x)

src/combinators/superpose.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ oneplus(x::ULogarithmic) = exp(ULogarithmic, log1pexp(x.log))
7070
@inline function density_def(s::SuperpositionMeasure{Tuple{A,B}}, x) where {A,B}
7171
(μ, ν) = s.components
7272

73-
insupport(μ, x) || return exp(ULogarithmic, logdensity_def(ν, x))
74-
insupport(ν, x) || return exp(ULogarithmic, logdensity_def(μ, x))
73+
istrue(insupport(μ, x)) || return exp(ULogarithmic, logdensity_def(ν, x))
74+
istrue(insupport(ν, x)) || return exp(ULogarithmic, logdensity_def(μ, x))
7575

7676
α = basemeasure(μ)
7777
β = basemeasure(ν)
@@ -110,8 +110,8 @@ end
110110
) where {T<:(SuperpositionMeasure{Tuple{A,B}} where {A,B})}
111111
(μ, ν) = s.components
112112

113-
insupport(μ, x) == true || return logdensity_rel(ν, β, x)
114-
insupport(ν, x) == true || return logdensity_rel(μ, β, x)
113+
istrue(insupport(μ, x)) || return logdensity_rel(ν, β, x)
114+
istrue(insupport(ν, x)) || return logdensity_rel(μ, β, x)
115115
return logaddexp(logdensity_rel(μ, β, x), logdensity_rel(ν, β, x))
116116
end
117117

@@ -121,8 +121,8 @@ end
121121
x,
122122
) where {A,B}
123123
(μ, ν) = s.components
124-
insupport(μ, x) == true || return logdensity_rel(ν, β, x)
125-
insupport(ν, x) == true || return logdensity_rel(μ, β, x)
124+
istrue(insupport(μ, x)) || return logdensity_rel(ν, β, x)
125+
istrue(insupport(ν, x)) || return logdensity_rel(μ, β, x)
126126
return logaddexp(logdensity_rel(μ, β, x), logdensity_rel(ν, β, x))
127127
end
128128

src/interface.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function dynamic_basemeasure_depth(μ::M) where {M}
3030
π = proxy(μ)
3131
if static_hasmethod(basemeasure, Tuple{typeof(π)})
3232
basemeasure(π) == basemeasure(μ) && return dynamic_basemeasure_depth(π)
33-
end
33+
end
3434
end
3535
β = basemeasure(μ)
3636
depth = 0
@@ -112,17 +112,17 @@ function test_smf(μ, n = 100)
112112
p = rand(n)
113113
p .+= 0:n-1
114114
p .*= inv(n)
115-
115+
116116
F(x) = smf(μ, x)
117117
Finv(p) = invsmf(μ, p)
118-
118+
119119
@assert issorted(p)
120120
x = invsmf.(μ, p)
121121
@test issorted(x)
122122
@test all(insupport(μ), x)
123-
123+
124124
@test all((Finv F).(x) .≈ x)
125-
125+
126126
for j in 1:n
127127
a = rand()
128128
b = rand()

src/primitive.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ basemeasure(μ::PrimitiveMeasure) = μ
1919

2020
@inline basemeasure_depth(::PrimitiveMeasure) = static(0)
2121

22-
logdensity_def(μ::PrimitiveMeasure, x) = static(0.0)
22+
logdensity_def(::PrimitiveMeasure, x) = static(0.0)
2323

2424
logdensity_def::M, ν::M, x) where {M<:PrimitiveMeasure} = 0.0
2525

0 commit comments

Comments
 (0)