Skip to content

Commit 0509669

Browse files
authored
Standard-measures (#56)
Add `StdUniform`, `StdNormal`, `StdExponential`
1 parent c6b030d commit 0509669

File tree

10 files changed

+71
-13
lines changed

10 files changed

+71
-13
lines changed

Project.toml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.9.2"
4+
version = "0.9.3"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
88
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
99
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
1010
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1111
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
12+
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1415
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
@@ -28,6 +29,7 @@ ConstructionBase = "1.3"
2829
DensityInterface = "0.4"
2930
FillArrays = "0.12, 0.13"
3031
IfElse = "0.1"
32+
IrrationalConstants = "0.1"
3133
LogExpFunctions = "0.3"
3234
LogarithmicNumbers = "1"
3335
MappedArrays = "0.4"
@@ -40,9 +42,6 @@ julia = "1.3"
4042

4143
[extras]
4244
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
43-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
44-
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
45-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4645

4746
[targets]
48-
test = ["Test", "Aqua", "LinearAlgebra", "Statistics"]
47+
test = ["Aqua"]

src/MeasureBase.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
module MeasureBase
22

3-
const logtwo = log(2.0)
4-
53
using Random
64
import Random: rand!
75
import Random: gentype
@@ -94,6 +92,8 @@ function logdensity_def end
9492

9593
using Compat
9694

95+
using IrrationalConstants
96+
9797
include("schema.jl")
9898
include("splat.jl")
9999
include("proxies.jl")
@@ -124,6 +124,10 @@ include("combinators/smart-constructors.jl")
124124
include("combinators/powerweighted.jl")
125125
include("combinators/conditional.jl")
126126

127+
include("standard/stdnormal.jl")
128+
include("standard/stduniform.jl")
129+
include("standard/stdexponential.jl")
130+
127131
include("rand.jl")
128132

129133
include("density.jl")

src/combinators/half.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ end
1212
unhalf::Half) = μ.parent
1313

1414
@inline function basemeasure::Half)
15-
weightedmeasure(static(logtwo), basemeasure(unhalf(μ)))
15+
weightedmeasure(logtwo, basemeasure(unhalf(μ)))
1616
end
1717

1818
function Base.rand(rng::AbstractRNG, ::Type{T}, μ::Half) where {T}

src/combinators/smart-constructors.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ superpose(nt::NamedTuple) = SuperpositionMeasure(nt)
8080

8181
function superpose::T, ν::T) where {T<:AbstractMeasure}
8282
if μ == ν
83-
return weightedmeasure(static(logtwo), μ)
83+
return weightedmeasure(logtwo, μ)
8484
else
8585
return superpose((μ, ν))
8686
end
@@ -127,7 +127,7 @@ function kernel(d::PowerMeasure)
127127
end
128128

129129
function kernel(f)
130-
T = Core.Compiler.return_type(f, Tuple{Any} )
130+
T = Core.Compiler.return_type(f, Tuple{Any})
131131
_kernel(f, T)
132132
end
133133

src/density.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ Define a new measure in terms of a log-density `f` over some measure `base`.
100100
"""
101101
∫exp(f::Function, μ) = (logfuncdensity(f), μ)
102102

103-
104103
"""
105104
logdensityof(m::AbstractMeasure, x)
106105

src/kernel.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ end
7676

7777
(k::AbstractTransitionKernel)(; kwargs...) = k(NamedTuple(kwargs))
7878

79-
8079
"""
8180
For any `k::TransitionKernel`, `basekernel` is expected to satisfy
8281
```
@@ -107,7 +106,6 @@ function Pretty.tile(k::K) where {K<:AbstractTransitionKernel}
107106
)
108107
end
109108

110-
111109
const kleisli = kernel
112110

113111
export kleisli

src/standard/stdexponential.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
struct StdExponential <: AbstractMeasure end
2+
3+
export StdExponential
4+
5+
insupport(d::StdExponential, x) = x zero(x)
6+
7+
@inline logdensity_def(::StdExponential, x) = -x
8+
@inline basemeasure(::StdExponential) = Lebesgue()
9+
10+
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T}
11+
randexp(rng, T)
12+
end

src/standard/stdnormal.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
struct StdNormal <: AbstractMeasure end
2+
3+
export StdNormal
4+
5+
insupport(d::StdNormal, x) = true
6+
insupport(d::StdNormal) = Returns(true)
7+
8+
@inline logdensity_def(::StdNormal, x) = -x^2 / 2
9+
@inline basemeasure(::StdNormal) = WeightedMeasure(static(-0.5 * log2π), Lebesgue(ℝ))
10+
11+
Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdNormal) where {T} = randn(rng, T)

src/standard/stduniform.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
struct StdUniform <: AbstractMeasure end
2+
3+
export StdUniform
4+
5+
insupport(d::StdUniform, x) = zero(x) x one(x)
6+
7+
@inline logdensity_def(::StdUniform, x) = zero(x)
8+
@inline basemeasure(::StdUniform) = Lebesgue()
9+
10+
Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdUniform) where {T} = randn(rng, T)

test/runtests.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,31 @@ test_measures = [
3737
Dirac(0) + Dirac(1)
3838
Dirac(0.0) + Lebesgue(ℝ)
3939
SpikeMixture(Lebesgue(ℝ), 0.2)
40+
41+
StdNormal()
42+
StdNormal()^3
43+
StdNormal()^(2, 3)
44+
3 * StdNormal()
45+
0.2 * StdNormal() + 0.8 * Dirac(0.0)
46+
Dirac(0.0) + StdNormal()
47+
SpikeMixture(StdNormal(), 0.2)
48+
49+
StdUniform()
50+
StdUniform()^3
51+
StdUniform()^(2, 3)
52+
3 * StdUniform()
53+
0.2 * StdUniform() + 0.8 * Dirac(0.0)
54+
Dirac(0.0) + StdUniform()
55+
SpikeMixture(StdUniform(), 0.2)
56+
57+
StdExponential()^3
58+
StdExponential()^(2, 3)
59+
3 * StdExponential()
60+
StdExponential()
61+
0.2 * StdExponential() + 0.8 * Dirac(0.0)
62+
Dirac(0.0) + StdExponential()
63+
SpikeMixture(StdExponential(), 0.2)
64+
4065
# d ⊙ d
4166
]
4267

0 commit comments

Comments
 (0)