Skip to content

Commit 8ca967c

Browse files
committed
2 parents fe43a92 + edec56d commit 8ca967c

28 files changed

+275
-313
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1313
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
1414
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"
15+
PrettyPrinting = "54e16d92-306c-5ea0-a30b-337be88ac337"
1516
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1617
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
1718

@@ -23,6 +24,7 @@ KeywordCalls = "0.2"
2324
LogExpFunctions = "0.3"
2425
MLStyle = "0.4"
2526
MappedArrays = "0.4"
27+
PrettyPrinting = "0.3"
2628
Tricks = "0.1"
2729
julia = "1.3"
2830

src/MeasureBase.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ export AbstractMeasure
1616

1717
abstract type AbstractMeasure end
1818

19+
import PrettyPrinting
20+
21+
const Pretty = PrettyPrinting
22+
1923
sampletype::AbstractMeasure) = typeof(testvalue(μ))
2024

2125
# sampletype(μ::AbstractMeasure) = sampletype(basemeasure(μ))
@@ -37,7 +41,6 @@ Methods for computing density relative to other measures will be
3741
"""
3842
function logdensity end
3943

40-
4144
if VERSION < v"1.7.0-beta1.0"
4245
@eval begin
4346
struct Returns{T}
@@ -57,7 +60,6 @@ include("domains.jl")
5760
include("utils.jl")
5861
include("absolutecontinuity.jl")
5962
include("macros.jl")
60-
include("resettablerng.jl")
6163

6264
include("primitive.jl")
6365
include("primitives/counting.jl")

src/combinators/affine.jl

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,42 @@ using LinearAlgebra
44
par::NamedTuple{N,T}
55
end
66

7+
quoteof(f::AffineTransform) = :(AffineTransform($(quoteof(f.par))))
8+
79
params(f::AffineTransform) = getfield(f, :par)
810

911
@inline Base.getproperty(d::AffineTransform, s::Symbol) = getfield(getfield(d, :par), s)
1012

11-
Base.propertynames(d::AffineTransform{N}) where {N} = N
13+
Base.propertynames(d::AffineTransform{N}) where {N} = N
1214

13-
@inline Base.inv(f::AffineTransform{(:μ,:σ)}) = AffineTransform((μ = -(f.σ \ f.μ), ω = f.σ))
14-
@inline Base.inv(f::AffineTransform{(:μ,:ω)}) = AffineTransform((μ = - f.ω * f.μ, σ = f.ω))
15+
@inline Base.inv(f::AffineTransform{(:μ, :σ)}) =
16+
AffineTransform((μ = -(f.σ \ f.μ), ω = f.σ))
17+
@inline Base.inv(f::AffineTransform{(:μ, :ω)}) = AffineTransform((μ = -f.ω * f.μ, σ = f.ω))
1518
@inline Base.inv(f::AffineTransform{(:σ,)}) = AffineTransform((ω = f.σ,))
1619
@inline Base.inv(f::AffineTransform{(:ω,)}) = AffineTransform((σ = f.ω,))
1720
@inline Base.inv(f::AffineTransform{(:μ,)}) = AffineTransform((μ = -f.μ,))
1821

1922
# `size(f) == (m,n)` means `f : ℝⁿ → ℝᵐ`
20-
Base.size(f::AffineTransform{(:μ,:σ)}) = size(f.σ)
21-
Base.size(f::AffineTransform{(:μ,:ω)}) = size(f.ω)
22-
Base.size(f::AffineTransform{(:σ,)}) = size(f.σ)
23-
Base.size(f::AffineTransform{(:ω,)}) = size(f.ω)
23+
Base.size(f::AffineTransform{(:μ, :σ)}) = size(f.σ)
24+
Base.size(f::AffineTransform{(:μ, :ω)}) = size(f.ω)
25+
Base.size(f::AffineTransform{(:σ,)}) = size(f.σ)
26+
Base.size(f::AffineTransform{(:ω,)}) = size(f.ω)
2427

2528
function Base.size(f::AffineTransform{(:μ,)})
2629
(n,) = size(f.μ)
27-
return (n,n)
30+
return (n, n)
2831
end
2932

3033
Base.size(f::AffineTransform, n::Int) = @inbounds size(f)[n]
3134

3235
(f::AffineTransform{(:μ,)})(x) = x + f.μ
3336
(f::AffineTransform{(:σ,)})(x) = f.σ * x
3437
(f::AffineTransform{(:ω,)})(x) = f.ω \ x
35-
(f::AffineTransform{(:μ,:σ)})(x) = f.σ * x + f.μ
36-
(f::AffineTransform{(:μ,:ω)})(x) = f.ω \ x + f.μ
38+
(f::AffineTransform{(:μ, :σ)})(x) = f.σ * x + f.μ
39+
(f::AffineTransform{(:μ, :ω)})(x) = f.ω \ x + f.μ
3740

3841
rowsize(x) = ()
39-
rowsize(x::AbstractArray) = (size(x,1),)
42+
rowsize(x::AbstractArray) = (size(x, 1),)
4043

4144
function rowsize(f::AffineTransform)
4245
size_f = size(f)
@@ -46,7 +49,7 @@ function rowsize(f::AffineTransform)
4649
end
4750

4851
colsize(x) = ()
49-
colsize(x::AbstractArray) = (size(x,2),)
52+
colsize(x::AbstractArray) = (size(x, 2),)
5053

5154
function colsize(f::AffineTransform)
5255
size_f = size(f)
@@ -65,7 +68,7 @@ end
6568
return x
6669
end
6770

68-
@inline function apply!(x, f::AffineTransform{(:ω,), Tuple{F}}, z) where {F<:Factorization}
71+
@inline function apply!(x, f::AffineTransform{(:ω,),Tuple{F}}, z) where {F<:Factorization}
6972
ldiv!(x, f.ω, z)
7073
return x
7174
end
@@ -75,20 +78,20 @@ end
7578
return x
7679
end
7780

78-
@inline function apply!(x, f::AffineTransform{(:μ,:σ)}, z)
81+
@inline function apply!(x, f::AffineTransform{(:μ, :σ)}, z)
7982
apply!(x, AffineTransform((σ = f.σ,)), z)
8083
apply!(x, AffineTransform((μ = f.μ,)), x)
8184
return x
8285
end
8386

84-
@inline function apply!(x, f::AffineTransform{(:μ,:ω)}, z)
87+
@inline function apply!(x, f::AffineTransform{(:μ, :ω)}, z)
8588
apply!(x, AffineTransform((ω = f.ω,)), z)
8689
apply!(x, AffineTransform((μ = f.μ,)), x)
8790
return x
8891
end
8992

90-
function logjac(x::AbstractMatrix)
91-
(m,n) = size(x)
93+
function logjac(x::AbstractMatrix)
94+
(m, n) = size(x)
9295
m == n && return first(logabsdet(x))
9396

9497
# Equivalent to sum(log, svdvals(x)), but much faster
@@ -99,8 +102,8 @@ end
99102
logjac(x::Number) = log(abs(x))
100103

101104
# TODO: `log` doesn't work for the multivariate case, we need the log absolute determinant
102-
logjac(f::AffineTransform{(:μ,:σ)}) = logjac(f.σ)
103-
logjac(f::AffineTransform{(:μ,:ω)}) = -logjac(f.ω)
105+
logjac(f::AffineTransform{(:μ, :σ)}) = logjac(f.σ)
106+
logjac(f::AffineTransform{(:μ, :ω)}) = -logjac(f.ω)
104107
logjac(f::AffineTransform{(:σ,)}) = logjac(f.σ)
105108
logjac(f::AffineTransform{(:ω,)}) = -logjac(f.ω)
106109
logjac(f::AffineTransform{(:μ,)}) = 0.0
@@ -130,16 +133,16 @@ function params(μ::Affine)
130133
return merge(nt1, nt2)
131134
end
132135

133-
function paramnames(::Type{A}) where {N,M, A<:Affine{N,M}}
136+
function paramnames(::Type{A}) where {N,M,A<:Affine{N,M}}
134137
tuple(union(N, paramnames(M))...)
135138
end
136139

137-
Base.propertynames(d::Affine{N}) where {N} = N (:parent,:f)
140+
Base.propertynames(d::Affine{N}) where {N} = N (:parent, :f)
138141

139-
@inline function Base.getproperty(d::Affine, s::Symbol)
142+
@inline function Base.getproperty(d::Affine, s::Symbol)
140143
if s === :parent
141144
return getfield(d, :parent)
142-
elseif s === :f
145+
elseif s === :f
143146
return getfield(d, :f)
144147
else
145148
return getproperty(getfield(d, :f), s)
@@ -166,18 +169,18 @@ end
166169

167170
function logdensity(d::Affine{(:μ,)}, x)
168171
z = x - d.μ
169-
logdensity(d.parent, z)
172+
logdensity(d.parent, z)
170173
end
171174

172-
function logdensity(d::Affine{(:μ,:σ)}, x)
175+
function logdensity(d::Affine{(:μ, :σ)}, x)
173176
z = d.σ \ (x - d.μ)
174-
logdensity(d.parent, z)
177+
logdensity(d.parent, z)
175178
end
176179

177-
function logdensity(d::Affine{(:μ,:ω)}, x)
180+
function logdensity(d::Affine{(:μ, :ω)}, x)
178181
z = d.ω * (x - d.μ)
179182
logdensity(d.parent, z)
180-
end
183+
end
181184

182185
# # logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
183186
# @inline function logdensity(d::Affine{(:μ,:σ), P, Tuple{V,M}}, x) where {P, V<:AbstractVector, M<:AbstractMatrix}
@@ -190,7 +193,7 @@ end
190193
# end
191194
# sum(zⱼ -> logdensity(d.parent, zⱼ), z)
192195
# end
193-
196+
194197
# # logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
195198
# @inline function logdensity(d::Affine{(:μ,:ω), P,Tuple{V,M}}, x) where {P,V<:AbstractVector, M<:AbstractMatrix}
196199
# z = x - d.μ
@@ -202,31 +205,35 @@ basemeasure(d::Affine) = affine(getfield(d, :f), basemeasure(d.parent))
202205

203206
# We can't do this until we know we're working with Lebesgue measure, since for
204207
# example it wouldn't make sense to apply a log-Jacobian to a point measure
205-
basemeasure(d::Affine{N,L}) where {N, L<:Lebesgue} = weightedmeasure(-logjac(d), d.parent)
208+
basemeasure(d::Affine{N,L}) where {N,L<:Lebesgue} = weightedmeasure(-logjac(d), d.parent)
206209

207-
function basemeasure(d::Affine{N,M}) where {N,L<:Lebesgue, M<:ProductMeasure{Returns{L}}}
210+
function basemeasure(d::Affine{N,M}) where {N,L<:Lebesgue,M<:ProductMeasure{Returns{L}}}
208211
weightedmeasure(-logjac(d), d.parent)
209212
end
210213

211214
logjac(d::Affine) = logjac(getfield(d, :f))
212215

213-
function Random.rand!(rng::Random.AbstractRNG, d::Affine, x::AbstractVector{T}, z=Vector{T}(undef, size(getfield(d,:f),2))) where {T}
216+
function Random.rand!(
217+
rng::Random.AbstractRNG,
218+
d::Affine,
219+
x::AbstractVector{T},
220+
z = Vector{T}(undef, size(getfield(d, :f), 2))
221+
) where {T}
214222
rand!(rng, parent(d), z)
215223
f = getfield(d, :f)
216224
apply!(x, f, z)
217225
return x
218226
end
219227

220-
221228
# function Base.rand(rng::Random.AbstractRNG, ::Type{T}, d::Affine) where {T}
222229
# f = getfield(d, :f)
223230
# z = rand(rng, T, parent(d))
224231
# apply!(x, f, z)
225232
# return z
226233
# end
227234

228-
supportdim(nt::NamedTuple{(:μ,:σ)}) = colsize(nt.σ)
229-
supportdim(nt::NamedTuple{(:μ,:ω)}) = rowsize(nt.ω)
230-
supportdim(nt::NamedTuple{(:σ,)}) = colsize(nt.σ)
231-
supportdim(nt::NamedTuple{(:ω,)}) = rowsize(nt.ω)
232-
supportdim(nt::NamedTuple{(:μ,)}) = size(nt.μ)
235+
supportdim(nt::NamedTuple{(:μ, :σ)}) = colsize(nt.σ)
236+
supportdim(nt::NamedTuple{(:μ, :ω)}) = rowsize(nt.ω)
237+
supportdim(nt::NamedTuple{(:σ,)}) = colsize(nt.σ)
238+
supportdim(nt::NamedTuple{(:ω,)}) = rowsize(nt.ω)
239+
supportdim(nt::NamedTuple{(:μ,)}) = size(nt.μ)

src/combinators/factoredbase.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
export FactoredBase
22

33
struct FactoredBase{R,C,V,B} <: AbstractMeasure
4-
inbounds :: R
5-
const :: C
6-
varℓ :: V
7-
base :: B
4+
inbounds::R
5+
const::C
6+
varℓ::V
7+
base::B
88
end
99

1010
@inline function logdensity(d::FactoredBase, x)

src/combinators/for.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ export For
33
using Random
44
import Base
55

6-
76
"""
87
For(f, base...)
98
@@ -79,10 +78,9 @@ For(f, dims...) = productmeasure(i -> f(i...), zip(dims...))
7978

8079
For(f, inds::AbstractArray) = productmeasure(f, inds)
8180

82-
For(f, n::Int) = productmeasure(f, 1:n)
81+
For(f, n::Int) = productmeasure(f, Base.OneTo(n))
8382
For(f, dims::Int...) = productmeasure(i -> f(Tuple(i)...), CartesianIndices(dims))
8483

85-
8684
function Base.eltype(d::ProductMeasure{F,I}) where {F,I<:AbstractArray}
8785
return eltype(d.f(first(d.pars)))
8886
end

src/combinators/likelihood.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,16 @@ struct Likelihood{F,S,X}
106106
x::X
107107
end
108108

109-
Likelihood::AbstractMeasure, x) = Likelihood(kernel(μ),x)
109+
Likelihood::AbstractMeasure, x) = Likelihood(kernel(μ), x)
110110

111-
Likelihood(::Type{M}, x) where {M<:AbstractMeasure} = Likelihood(kernel(M),x)
111+
Likelihood(::Type{M}, x) where {M<:AbstractMeasure} = Likelihood(kernel(M), x)
112112

113-
function Base.show(io::IO, ℓ::Likelihood)
113+
function Base.show(io::IO, ℓ::Likelihood)
114114
io = IOContext(io, :compact => true)
115115
k, x =.k, ℓ.x
116-
print(io, "Likelihood(",k,", ", x, ")")
116+
print(io, "Likelihood(", k, ", ", x, ")")
117117
end
118118

119-
120-
function logdensity(ℓ::Likelihood, p)
119+
function logdensity(ℓ::Likelihood, p)
121120
return logdensity(ℓ.k(p), ℓ.x)
122121
end

src/combinators/mapsto.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ end
55

66
export , mapsto
77

8-
mapsto(x,y) = x y
8+
mapsto(x, y) = x y
99

10-
(x::X,y::Y) where {X,Y} = MapsTo{X,Y}(x,y)
10+
(x::X, y::Y) where {X,Y} = MapsTo{X,Y}(x, y)
1111

1212
Base.first(t::MapsTo) = t.x
1313
Base.last(t::MapsTo) = t.y

src/combinators/pointwise.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export
22

33
@concrete terse struct PointwiseProductMeasure{T} <: AbstractMeasure
4-
data :: T
4+
data::T
55
end
66

77
Base.size::PointwiseProductMeasure) = size.data)
@@ -31,8 +31,8 @@ Base.length(m::PointwiseProductMeasure{T}) where {T} = length(m.data)
3131
sum((logdensity(dⱼ, x) for dⱼ in d.data))
3232
end
3333

34-
function sampletype(d::PointwiseProductMeasure)
34+
function sampletype(d::PointwiseProductMeasure)
3535
@inbounds sampletype(first(d.data))
3636
end
3737

38-
basemeasure(d::PointwiseProductMeasure) = @inbounds basemeasure(first(d.data))
38+
basemeasure(d::PointwiseProductMeasure) = @inbounds basemeasure(first(d.data))

src/combinators/power.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,28 +34,24 @@ const PowerMeasure{F,S,T,N,A} = ProductMeasure{F,S,Fill{T,N,A}}
3434
Base.:^::AbstractMeasure, ::Tuple{}) = μ
3535

3636
function Base.:^::AbstractMeasure, dims::Integer...)
37-
return μ ^ dims
37+
return μ^dims
3838
end
3939

4040
function Base.:^::M, dims::NTuple{N,I}) where {M<:AbstractMeasure,N,I<:Integer}
4141
powermeasure(μ, dims)
4242
end
4343

4444
# Same as PowerMeasure
45-
function Base.show(io::IO, d::ProductMeasure{Returns{T},I,C}) where {T,I,C<:CartesianIndices}
46-
io = IOContext(io, :compact => true)
47-
print(io, d.f.f.value, " ^ ", size(d.pars))
45+
function Pretty.tile(d::ProductMeasure{Returns{T},I,C}) where {T,I,C<:CartesianIndices}
46+
Pretty.pair_layout(Pretty.tile(d.f.f.value), Pretty.tile(size(d.pars)); sep = " ^ ")
4847
end
4948

50-
function Base.show(io::IO, d::ProductMeasure{R,I,V}) where {R<:Returns,I,V<:AbstractVector}
51-
io = IOContext(io, :compact => true)
52-
print(io, d.f.f.value, " ^ ", length(d.pars))
49+
function Pretty.tile(d::ProductMeasure{R,I,V}) where {R<:Returns,I,V<:AbstractVector}
50+
Pretty.pair_layout(Pretty.tile(d.f.f.value), Pretty.tile(length(d.pars)); sep = " ^ ")
5351
end
5452

5553
# sampletype(d::PowerMeasure{M,N}) where {M,N} = @inbounds Array{sampletype(first(marginals(d))), N}
5654

57-
58-
5955
params(d::ProductMeasure{F,S,<:Fill}) where {F,S} = params(first(marginals(d)))
6056

6157
params(::Type{P}) where {F,S,P<:ProductMeasure{F,S,<:Fill}} = params(D)
@@ -69,16 +65,16 @@ end
6965

7066
# Same as PowerMeasure
7167
@inline function _basemeasure(d::ProductMeasure{F,S,<:Fill}, b) where {F,S}
72-
b ^ size(d.pars)
68+
b^size(d.pars)
7369
end
7470

7571
# Same as PowerMeasure
7672
@inline function _basemeasure(d::ProductMeasure{F,S,<:Fill}, b::FactoredBase) where {F,S}
7773
n = length(d.pars)
78-
inbounds(x) = all(xj -> b.inbounds(xj), x)
74+
inbounds(x) = all(b.inbounds, x)
7975
const= n * b.const
8076
varℓ() = n * b.varℓ()
81-
base = b.base ^ size(d.pars)
77+
base = b.base^size(d.pars)
8278
FactoredBase(inbounds, constℓ, varℓ, base)
8379
end
8480

0 commit comments

Comments
 (0)