Skip to content

Commit 8e7a4e9

Browse files
committed
Specialize logdensityof for primitive measures
We need maximum performance on these. Also try to preserve floating point types as far as possible.
1 parent 38afd27 commit 8e7a4e9

File tree

6 files changed

+66
-3
lines changed

6 files changed

+66
-3
lines changed

src/primitive.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ basemeasure(μ::PrimitiveMeasure) = μ
1919

2020
@inline basemeasure_depth(::PrimitiveMeasure) = static(0)
2121

22+
@inline logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x)))
23+
@inline logdensityof(::PrimitiveMeasure, x) = false
24+
2225
logdensity_def(::PrimitiveMeasure, x) = static(0.0)
2326

2427
logdensity_def::M, ν::M, x) where {M<:PrimitiveMeasure} = 0.0

src/primitives/counting.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,15 @@ struct Counting{T} <: AbstractMeasure
1212
Counting(supp) = new{Core.Typeof(supp)}(supp)
1313
end
1414

15-
function logdensity_def::Counting, x)
16-
insupport(μ, x) ? 0.0 : -Inf
15+
@inline function logdensityof::Counting, x::Real)
16+
R = float(typeof(x))
17+
insupport(μ, x) ? zero(R) : R(-Inf)
1718
end
1819

20+
@inline logdensityof::Counting, x) = insupport(μ, x) ? 0.0 : -Inf
21+
22+
@inline logdensity_def::Counting, x) = logdensityof(μ, x)
23+
1924
basemeasure(::Counting) = CountingBase()
2025

2126
Counting() = Counting(ℤ)

src/primitives/dirac.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,15 @@ basemeasure(d::Dirac) = CountingBase()
2020

2121
massof(::Dirac) = static(1.0)
2222

23-
logdensity_def::Dirac, x) = 0.0
23+
function logdensityof::Dirac, x::Real)
24+
R = float(typeof(x))
25+
insupport(μ, x) ? zero(R) : R(-Inf)
26+
end
27+
28+
logdensityof::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf
29+
30+
logdensity_def(::Dirac, x::Real) = zero(float(typeof(x)))
31+
logdensity_def(::Dirac, x) = 0.0
2432

2533
Base.rand(::Random.AbstractRNG, T::Type, μ::Dirac) = μ.x
2634

src/primitives/lebesgue.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ insupport(μ::Lebesgue, x) = x ∈ μ.support
6363

6464
insupport(::Lebesgue{RealNumbers}, ::Real) = true
6565

66+
@inline function logdensityof::Lebesgue, x::Real)
67+
R = float(typeof(x))
68+
insupport(μ, x) ? zero(R) : R(-Inf)
69+
end
70+
71+
@inline logdensityof::Lebesgue, x) = insupport(μ, x) ? 0.0 : -Inf
72+
6673
massof(::Lebesgue{RealNumbers}, s::Interval) = width(s)
6774

6875
# Example:

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ include("test_aqua.jl")
1111

1212
include("static.jl")
1313

14+
include("test_primitive.jl")
1415
include("test_standard.jl")
1516
include("test_basics.jl")
1617

test/test_primitive.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using Test
2+
3+
using MeasureBase
4+
using MeasureBase: insupport as measure_insupport
5+
6+
using DensityInterface: logdensityof
7+
8+
@testset "primitive" begin
9+
for (m, x) in [
10+
(MeasureBase.LebesgueBase(), -1.0f0),
11+
(MeasureBase.LebesgueBase(), 1.0f0),
12+
(Lebesgue(), 1.0f0),
13+
(Lebesgue(), -1.0f0),
14+
(MeasureBase.CountingBase(), -1.0f0),
15+
(MeasureBase.CountingBase(), 1.0f0),
16+
(Counting(), 2),
17+
(Counting(), 2.0f0),
18+
(Counting(), 1.5f0),
19+
(Dirac(4.2), 4.2f0),
20+
(Dirac(4.2), -1.0f0),
21+
(Dirac([1, 2, 3]), [1, 2, 3]),
22+
(Dirac([4, 5]), [4, 5]),
23+
]
24+
@testset "$(nameof(typeof(m)))" begin
25+
for x in [-Inf, -1.2, -1, 0, 0.0, 1 // 2, 0.5, 1, 1.0, 2, 2.3]
26+
if measure_insupport(m, x)
27+
@test @inferred(logdensityof(m, x)) 0
28+
if x isa Real
29+
ld = logdensityof(m, x)
30+
ld isa float(typeof(x))
31+
end
32+
@test MeasureBase.unsafe_logdensityof(m, x) 0
33+
else
34+
@test @inferred(logdensityof(m, x)) -Inf
35+
end
36+
end
37+
end
38+
end
39+
end

0 commit comments

Comments
 (0)