Skip to content

Commit 38afd27

Browse files
committed
Specialize and test logdensityof for standard measures
We need maximum performance for these.
1 parent 9042b49 commit 38afd27

File tree

8 files changed

+54
-10
lines changed

8 files changed

+54
-10
lines changed

src/standard/stdexponential.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@ struct StdExponential <: StdMeasure end
22

33
export StdExponential
44

5-
insupport(d::StdExponential, x) = x zero(x)
5+
insupport(::StdExponential, x) = x zero(x)
6+
7+
@inline function logdensityof(::StdExponential, x)
8+
R = float(typeof(x))
9+
x zero(R) ? convert(R, -x) : R(-Inf)
10+
end
611

712
@inline logdensity_def(::StdExponential, x) = -x
813
@inline basemeasure(::StdExponential) = LebesgueBase()

src/standard/stdlogistic.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ export StdLogistic
44

55
@inline insupport(d::StdLogistic, x) = true
66

7-
@inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2 * log1pexp(u))
7+
@inline logdensityof(::StdLogistic, x) = (u = -abs(x); u - 2 * log1pexp(u))
8+
9+
@inline logdensity_def(::StdLogistic, x) = logdensityof(StdLogistic(), x)
810
@inline basemeasure(::StdLogistic) = LebesgueBase()
911

1012
@inline transport_def(::StdUniform, μ::StdLogistic, x) = logistic(x)

src/standard/stdnormal.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
using SpecialFunctions: erfc, erfcinv
2-
using IrrationalConstants: invsqrt2
2+
using IrrationalConstants: invsqrt2, log2π
33

44
struct StdNormal <: StdMeasure end
55

66
export StdNormal
77

8-
@inline insupport(d::StdNormal, x) = true
8+
@inline insupport(::StdNormal, x) = true
9+
10+
@inline logdensityof(::StdNormal, x) = (-x^2 - log2π) / 2
911

1012
@inline logdensity_def(::StdNormal, x) = -x^2 / 2
1113
@inline basemeasure(::StdNormal) = WeightedMeasure(static(-0.5 * log2π), LebesgueBase())

src/standard/stduniform.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@ struct StdUniform <: StdMeasure end
22

33
export StdUniform
44

5-
insupport(d::StdUniform, x) = zero(x) x one(x)
5+
insupport(::StdUniform, x) = zero(x) x one(x)
6+
7+
@inline function logdensityof(::StdUniform, x)
8+
R = float(typeof(x))
9+
zero(x) x one(x) ? zero(R) : R(-Inf)
10+
end
611

712
@inline logdensity_def(::StdUniform, x) = zero(x)
813
@inline basemeasure(::StdUniform) = LebesgueBase()

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
44
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
55
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
6+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
67
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
78
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
89
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ include("test_aqua.jl")
1111

1212
include("static.jl")
1313

14+
include("test_standard.jl")
1415
include("test_basics.jl")
1516

1617
include("getdof.jl")

test/test_basics.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,13 @@ end
124124
@test logdensityof(Lebesgue()^3, 2) == logdensityof(Lebesgue()^(3, 1), (2, 0))
125125
end
126126

127-
Normal() = ∫exp(x -> -0.5x^2, Lebesgue(ℝ))
127+
NormalMeasure() = ∫exp(x -> -0.5x^2, Lebesgue(ℝ))
128128

129129
@testset "Half" begin
130-
HalfNormal() = Half(Normal())
130+
HalfNormal() = Half(NormalMeasure())
131131
@test logdensityof(HalfNormal(), -0.2) == -Inf
132-
@test logdensity_def(HalfNormal(), 0.2) == logdensity_def(Normal(), 0.2)
133-
@test densityof(HalfNormal(), 0.2) 2 * densityof(Normal(), 0.2)
132+
@test logdensity_def(HalfNormal(), 0.2) == logdensity_def(NormalMeasure(), 0.2)
133+
@test densityof(HalfNormal(), 0.2) 2 * densityof(NormalMeasure(), 0.2)
134134
end
135135

136136
@testset "Likelihood" begin
@@ -218,7 +218,7 @@ end
218218
@test log(f(x)) x^2
219219
end
220220

221-
let f = log𝒹(∫exp(x -> x^2, Normal()), Normal())
221+
let f = log𝒹(∫exp(x -> x^2, NormalMeasure()), NormalMeasure())
222222
@test f(x) x^2
223223
end
224224
end

test/test_standard.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using Test
2+
using MeasureBase
3+
using MeasureBase: insupport as measure_insupport
4+
5+
using DensityInterface: logdensityof
6+
import Distributions: insupport as dist_insupport
7+
using Distributions: Normal, Exponential, Logistic, Uniform
8+
9+
@testset "standard" begin
10+
for (m, d) in [
11+
(StdUniform(), Uniform()),
12+
(StdExponential(), Exponential()),
13+
(StdLogistic(), Logistic()),
14+
(StdNormal(), Normal()),
15+
]
16+
@testset "$(nameof(typeof(m)))" begin
17+
for x in [-Inf, -1.2, -1, 0, 0.0, 1 // 2, 0.5, 1, 1.0, 2, 2.3]
18+
@test @inferred(logdensityof(m, x)) logdensityof(d, x)
19+
@test @inferred(logdensityof(m, x)) logdensity_rel(m, rootmeasure(m), x)
20+
@test measure_insupport(m, x) dist_insupport(d, x)
21+
if measure_insupport(m, x)
22+
@test @inferred(MeasureBase.unsafe_logdensityof(m, x))
23+
logdensityof(d, x)
24+
end
25+
end
26+
end
27+
end
28+
end

0 commit comments

Comments
 (0)