Skip to content

Commit e603347

Browse files
authored
Dev (#166)
* Dirichlet(k::Integer, α) = Dirichlet(Fill(α, k)) * export TransformVariables as TV * drop redundant import * 0.0 => zero(Float64) * drop outdated Dists.logpdf * update StudentT * drop redundant import * update Uniform * bump MeasureBase version * reworking beta * small update to StudentT * basemeasure for discrete Distributions * using LogExpFunctions => import LogExpFunctions * quoteof(::Chain) * prettyprinting and chain-mucking * Some refactoring for Markov chains * import MeasureBase: ≪ * version bound for PrettyPrinting * copy(rng) might change its type (e.g. GLOBAL_RNG) * tests pass * cleaning up * more cleanup * bump versions * change import..as so it's 1.5 friendly * bump version
1 parent 73e0a07 commit e603347

File tree

9 files changed

+322
-95
lines changed

9 files changed

+322
-95
lines changed

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureTheory"
22
uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.13.0"
4+
version = "0.13.1"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -11,6 +11,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1111
DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38"
1212
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1313
InfiniteArrays = "4858937d-0d70-526a-a4dd-2d5cb5dd786c"
14+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1415
KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f"
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1617
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
@@ -21,6 +22,7 @@ MeasureBase = "fa1605e6-acd5-459c-a1e6-7e635759db14"
2122
NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50"
2223
NestedTuples = "a734d2a7-8d68-409b-9419-626914d4061d"
2324
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
25+
PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337"
2426
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2527
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2628
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -35,16 +37,17 @@ ConstructionBase = "1.3"
3537
Distributions = "0.25"
3638
DynamicIterators = "0.4"
3739
FillArrays = "0.12"
38-
InfiniteArrays = "0.11"
40+
InfiniteArrays = "0.11, 0.12"
3941
KeywordCalls = "0.2"
4042
LogExpFunctions = "0.3.3"
4143
MLStyle = "0.4"
4244
MacroTools = "0.5"
4345
MappedArrays = "0.4"
44-
MeasureBase = "0.4"
46+
MeasureBase = "0.5"
4547
NamedTupleTools = "0.13"
4648
NestedTuples = "0.3"
4749
PositiveFactorizations = "0.2"
50+
PrettyPrinting = "0.3"
4851
Reexport = "1"
4952
SpecialFunctions = "1"
5053
StatsFuns = "0.9"

src/MeasureTheory.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import Base
1313
import Distributions
1414
const Dists = Distributions
1515

16-
# export TV
16+
export TV
1717
export
1818
export sampletype
1919
export For
@@ -38,12 +38,17 @@ using ConstructionBase
3838
using Accessors
3939
using StatsFuns
4040
using SpecialFunctions
41-
using LogExpFunctions
4241

42+
import LogExpFunctions
4343
import NamedTupleTools
4444

4545
import MeasureBase: testvalue, logdensity, density, basemeasure, kernel, params, ∫
46-
import MeasureBase: affine, supportdim
46+
import MeasureBase: affine, supportdim,
47+
using MeasureBase: constructor
48+
49+
import PrettyPrinting
50+
51+
const Pretty = PrettyPrinting
4752

4853
import Base: rand
4954

@@ -65,9 +70,6 @@ sampletype(μ::AbstractMeasure) = typeof(testvalue(μ))
6570

6671
# sampletype(μ::AbstractMeasure) = sampletype(basemeasure(μ))
6772

68-
69-
70-
7173
import Distributions: logpdf, pdf
7274

7375
export pdf, logpdf
@@ -94,6 +96,12 @@ const AFFINEPARS = [
9496
(,)
9597
]
9698

99+
xlogy(x::Number, y::Number) = LogExpFunctions.xlogy(x, y)
100+
xlogy(x, y) = x * log(y)
101+
102+
xlog1py(x::Number, y::Number) = LogExpFunctions.xlog1py(x, y)
103+
xlog1py(x, y) = x * log1p(y)
104+
97105

98106
include("const.jl")
99107
# include("traits.jl")
@@ -104,6 +112,9 @@ include("combinators/affine.jl")
104112
include("combinators/weighted.jl")
105113
include("combinators/product.jl")
106114
include("combinators/transforms.jl")
115+
116+
include("resettable-rng.jl")
117+
include("realized.jl")
107118
include("combinators/chain.jl")
108119

109120
include("distributions.jl")

src/combinators/chain.jl

Lines changed: 43 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
using ConcreteStructs
22
using DynamicIterators
3-
using DynamicIterators: dub
3+
import DynamicIterators: dub, dyniterate, evolve
44
using Base.Iterators: SizeUnknown, IsInfinite
5-
import DynamicIterators: dyniterate, evolve
65

76
import MeasureBase: For
87

8+
9+
910
export Chain
1011

1112
@concrete terse struct Chain{K,M} <: AbstractMeasure
1213
κ::K
1314
μ::M
1415
end
1516

17+
Pretty.quoteof(c::Chain) = :(Chain($(Pretty.quoteof(c.κ)), $(Pretty.quoteof(c.μ))))
18+
19+
Base.length(::Chain) =
20+
1621
function basemeasure(mc::Chain)
1722
Chain(basemeasure mc.κ, basemeasure(mc.μ))
1823
end
@@ -39,73 +44,19 @@ dyniterate(E::Chain, ::Nothing) = dub(evolve(E))
3944
Base.iterate(E::Chain) = dyniterate(E, nothing)
4045
Base.iterate(E::Chain, value) = dyniterate(E, value)
4146

42-
function DynamicIterators.dyniterate(r::Chain, (u,rng)::Sample)
43-
μ = r.κ(u)
44-
u = rand(rng, μ)
45-
return u, Sample(u, rng)
47+
function DynamicIterators.dyniterate(r::Chain, (x,rng)::Sample)
48+
μ = r.κ(x)
49+
y = rand(rng, μ)
50+
return y), Sample(y, rng)
4651
end
52+
4753
Base.IteratorSize(::Chain) = IsInfinite()
4854
Base.IteratorSize(::Type{Chain}) = IsInfinite()
4955

5056

51-
@concrete terse struct Realized{R,S,T} <: DynamicIterators.DynamicIterator
52-
rng::ResettableRNG{R,S}
53-
iter::T
54-
end
55-
56-
Base.IteratorEltype(mc::Realized) = Base.HasEltype()
57-
58-
function Base.eltype(::Type{Rz}) where {R,S,T,Rz <: Realized{R,S,T}}
59-
eltype(T)
60-
end
61-
62-
Base.length(r::Realized) = length(r.iter)
63-
64-
Base.size(r::Realized) = size(r.iter)
65-
66-
Base.IteratorSize(::Type{Rz}) where {R,S,T, Rz <: Realized{R,S,T}} = Base.IteratorSize(T)
67-
Base.IteratorSize(r::Rz) where {R,S,T, Rz <: Realized{R,S,T}} = Base.IteratorSize(r.iter)
68-
69-
70-
function Base.iterate(rv::Realized{R,S,T}) where {R,S,T}
71-
if static_hasmethod(evolve, Tuple{T})
72-
dyniterate(rv, nothing)
73-
else
74-
!isnothing(rv.rng.seed) && reset!(rv.rng)
75-
μ,s = iterate(rv.iter)
76-
x = rand(rv.rng, μ)
77-
x,s
78-
end
79-
end
80-
81-
82-
function Base.iterate(rv::Realized{R,S,T}, s) where {R,S,T}
83-
if static_hasmethod(evolve, Tuple{T})
84-
dyniterate(rv, s)
85-
else
86-
μs = iterate(rv.iter, s)
87-
isnothing(μs) && return nothing
88-
(μ,s) = μs
89-
x = rand(rv.rng, μ)
90-
return x,s
91-
end
92-
end
93-
94-
95-
function dyniterate(rv::Realized, ::Nothing)
96-
!isnothing(rv.rng.seed) && reset!(rv.rng)
97-
μ = evolve(rv.iter)
98-
x = rand(rv.rng, μ)
99-
x, Sample(x, rv.rng)
100-
end
101-
function dyniterate(rv::Realized, u::Sample)
102-
dyniterate(rv.iter, u)
103-
end
104-
10557
function Base.rand(rng::AbstractRNG, T::Type, chain::Chain)
106-
seed = rand(rng, UInt)
107-
r = ResettableRNG(rng, seed)
108-
return Realized(r, chain)
58+
r = ResettableRNG(rng)
59+
return RealizedSamples(r, chain)
10960
end
11061

11162
###############################################################################
@@ -115,38 +66,53 @@ end
11566

11667
@concrete terse struct DynamicFor{T,K,S} <: AbstractMeasure
11768
κ ::K
118-
sampler :: S
69+
iter :: S
11970
end
12071

121-
function DynamicFor::K,sampler::S) where {K,S}
122-
T = typeof(κ(first(sampler)))
123-
DynamicFor{T,K,S}(κ,sampler)
72+
Pretty.quoteof(r::DynamicFor) = :(DynamicFor($(Pretty.quoteof(r.κ)), $(Pretty.quoteof(r.iter))))
73+
74+
function DynamicFor::K,iter::S) where {K,S}
75+
T = typeof(κ(first(iter)))
76+
DynamicFor{T,K,S}(κ,iter)
77+
end
78+
79+
function Base.rand(rng::AbstractRNG, T::Type, df::DynamicFor)
80+
r = ResettableRNG(rng)
81+
return RealizedSamples(r, df)
82+
end
83+
84+
function logdensity(df::DynamicFor, y)
85+
= 0.0
86+
for (xj, yj) in zip(df.iter, y)
87+
+= logdensity(df.κ(xj), yj)
88+
end
89+
return
12490
end
12591

12692
Base.eltype(::Type{D}) where {T,D<:DynamicFor{T}} = eltype(T)
12793

12894
Base.IteratorEltype(d::DynamicFor) = Base.HasEltype()
12995

130-
Base.IteratorSize(d::DynamicFor) = Base.IteratorSize(d.sampler)
96+
Base.IteratorSize(d::DynamicFor) = Base.IteratorSize(d.iter)
13197

13298
function Base.iterate(d::DynamicFor)
133-
(x,s) = iterate(d.sampler)
99+
(x,s) = iterate(d.iter)
134100
(d.κ(x), s)
135101
end
136102

137103
function Base.iterate(d::DynamicFor, s)
138-
(x,s) = iterate(d.sampler, s)
104+
(x,s) = iterate(d.iter, s)
139105
(d.κ(x), s)
140106
end
141107

142-
Base.length(d::DynamicFor) = length(d.sampler)
108+
Base.length(d::DynamicFor) = length(d.iter)
143109

144110

145-
For(f, r::Realized) = DynamicFor(f,r)
111+
For(f, r::Realized) = DynamicFor(f, copy(r))
146112

147113
function Base.rand(rng::AbstractRNG, dfor::DynamicFor)
148-
seed = rand(rng, UInt)
149-
return Realized(seed, copy(rng), dfor)
114+
r = ResettableRNG(rng)
115+
return RealizedSamples(r, dfor)
150116
end
151117

152118
function dyniterate(df::DynamicFor, st, args...)
@@ -158,11 +124,11 @@ For(f, it::DynamicIterator) = DynamicFor(f, it)
158124

159125
For(f, it::DynamicFor) = DynamicFor(f, it)
160126

161-
function dyniterate(fr::DynamicFor, state)
162-
ϕ = dyniterate(fr.iter, state)
127+
function dyniterate(df::DynamicFor, state)
128+
ϕ = dyniterate(df.iter, state)
163129
ϕ === nothing && return nothing
164130
u, state = ϕ
165-
fr.f(u), state
131+
df.f(u), state
166132
end
167133

168134
function Base.collect(r::Realized)

src/distributions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ function basemeasure(μ::Dists.Distribution{Dists.Univariate,Dists.Continuous})
3333
return Lebesgue(ℝ)
3434
end
3535

36+
function basemeasure::Dists.Distribution{Dists.Univariate,Dists.Discrete})
37+
return CountingMeasure(ℤ)
38+
end
39+
3640
(::typeof(identity), ::Dists.Distribution) = 1.0
3741

3842
logdensity::Dists.Distribution, x) = Dists.logpdf(μ,x)

src/parameterized/beta.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
export Beta
44

5-
@parameterized Beta(α,β) Lebesgue(𝕀)
5+
@parameterized Beta(α,β)
66

77
@kwstruct Beta(α, β)
88

@@ -15,8 +15,20 @@ export Beta
1515

1616
TV.as(::Beta) = as𝕀
1717

18-
function logdensity(d::Beta{(:α, :β)}, x)
19-
return xlogy(d.α - 1, x) + xlog1py(d.β - 1, -x) - logbeta(d.α, d.β)
18+
function logdensity(d::Beta{(:α, :β), Tuple{A,B}}, x::X) where {A,B,X}
19+
if static_hasmethod(xlogy, Tuple{A,X}) && static_hasmethod(xlog1py, Tuple{B,X})
20+
return xlogy(d.α - 1, x) + xlog1py(d.β - 1, -x)
21+
else
22+
return (d.α - 1) * log(x) + (d.β - 1) * log1p(-x)
23+
end
24+
end
25+
26+
function basemeasure(d::Beta{(:α,:β)})
27+
inbounds(x) = 0 < x < 1
28+
const= 0.0
29+
varℓ() = - logbeta(d.α, d.β)
30+
base = Lebesgue(ℝ)
31+
FactoredBase(inbounds, constℓ, varℓ, base)
2032
end
2133

2234
Base.rand(rng::AbstractRNG, T::Type, μ::Beta) = rand(rng, Dists.Beta.α, μ.β))

src/parameterized/studentt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ function logdensity(d::StudentT{(:ν,)}, x)
4444
return+ 1) / (-2) * log1p(x^2 / ν)
4545
end
4646

47-
function basemeasure(d::StudentT{(:ν,)})
48-
inbounds(x) = true
47+
@inline function basemeasure(d::StudentT{(:ν,)})
48+
inbounds = Returns(true)
4949
const= 0.0
5050
varℓ() = loggamma((d.ν+1)/2) - loggamma(d.ν/2) - log* d.ν) / 2
5151
base = Lebesgue(ℝ)

0 commit comments

Comments
 (0)