Skip to content

Commit 1f9d99c

Browse files
authored
Affine (#23)
# `basekernel` For `k::Kernel`, `basekernel` has the property ```julia basemeasure(k(p)) == basekernel(k)(p) ``` There are many cases for which ```julia basekernel(k) isa Returns ``` in which case `p` doesn't matter. This can make things much more efficient, especially for large product measures. # Smart Constructors Many functions are now in terms of "smart constructors". These have lower-case names, for example `kernel` as opposed to `Kernel`. There's also a new `TupleProductMeasure` to handle cases like `Normal() ⊗ Lebesgue(ℝ)`. A given smart constructor should have a very small number of methods that call the struct itself; the remainder should all make a small reduction step, usually calling another smart constructor. # Changes to ProductMeasure `ProductMeasure` is now in terms of `Kernel`. # MapsTo When computing the logpdf of an affine measure, we currently have to recompute the pullback many times. As a step toward improving this, this PR introduces a `MapsTo` combinator that works like ```julia julia> typeof([1,2] ↦ 3) MeasureBase.MapsTo{Vector{Int64}, Int64} ``` The idea is that we can carry this around until the logpdf computation is done, and avoid recomputation in this way. ----------------------------------- * WIP * have logpdf use == instead of === (at least for now) * move kernel stuff to smart-constructors * some refactoring * fixing stuff * typo * AbstractProductMeasure * get tests passing * shorten basemeasure(d::ProductMeasure) * basekernel docstring * correct dispatch problem * update comment * thunks WIP * bump version
1 parent 0a69cef commit 1f9d99c

File tree

12 files changed

+247
-93
lines changed

12 files changed

+247
-93
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.4.4"
4+
version = "0.4.5"
55

66
[deps]
77
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"

src/MeasureBase.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ const logtwo = log(2.0)
55
using Random
66
import Random: rand!
77

8+
using FillArrays
89
using ConcreteStructs
910
using MLStyle
1011

@@ -21,6 +22,7 @@ sampletype(μ::AbstractMeasure) = typeof(testvalue(μ))
2122

2223
export logdensity
2324
export basemeasure
25+
export basekernel
2426

2527
using LogExpFunctions: logsumexp
2628

@@ -46,12 +48,14 @@ if VERSION < v"1.7.0-beta1.0"
4648
end
4749
end
4850

51+
include("kernel.jl")
52+
include("parameterized.jl")
53+
include("combinators/mapsto.jl")
4954
include("combinators/half.jl")
5055
include("exp.jl")
5156
include("domains.jl")
5257
include("utils.jl")
5358
include("absolutecontinuity.jl")
54-
include("parameterized.jl")
5559
include("macros.jl")
5660
include("resettablerng.jl")
5761

@@ -69,7 +73,6 @@ include("combinators/for.jl")
6973
include("combinators/power.jl")
7074
include("combinators/affine.jl")
7175
include("combinators/spikemixture.jl")
72-
include("kernel.jl")
7376
include("combinators/likelihood.jl")
7477
include("combinators/pointwise.jl")
7578
include("combinators/restricted.jl")

src/combinators/affine.jl

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ Base.size(f::AffineTransform, n::Int) = @inbounds size(f)[n]
3434
(f::AffineTransform{(:ω,)})(x) = f.ω \ x
3535
(f::AffineTransform{(:μ,:σ)})(x) = f.σ * x + f.μ
3636
(f::AffineTransform{(:μ,:ω)})(x) = f.ω \ x + f.μ
37+
1
38+
rowsize(x) = ()
39+
rowsize(x::AbstractArray) = (size(x,1),)
40+
41+
colsize(x) = ()
42+
colsize(x::AbstractArray) = (size(x,2),)
3743

3844
@inline function apply!(x, f::AffineTransform{(:μ,)}, z)
3945
x .= z .+ f.μ
@@ -130,14 +136,37 @@ Base.size(d::Affine) = size(d.μ)
130136
Base.size(d::Affine{(:σ,)}) = (size(d.σ, 1),)
131137
Base.size(d::Affine{(:ω,)}) = (size(d.ω, 2),)
132138

133-
logdensity(d::Affine{(:σ,)}, x) = logdensity(d.parent, d.σ \ x)
134-
logdensity(d::Affine{(:ω,)}, x) = logdensity(d.parent, d.ω * x)
135-
logdensity(d::Affine{(:μ,)}, x) = logdensity(d.parent, x - d.μ)
136-
logdensity(d::Affine{(:μ,:σ)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
137-
logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
139+
logdensity(d::Affine, x::MapsTo) = logdensity(d.parent, x.x)
140+
141+
function logdensity(d::Affine{(:σ,)}, x)
142+
z = d.σ \ x
143+
# @show z
144+
# println()
145+
logdensity(d.parent, z)
146+
end
147+
148+
function logdensity(d::Affine{(:ω,)}, x)
149+
z = d.ω * x
150+
logdensity(d.parent, z)
151+
end
152+
153+
function logdensity(d::Affine{(:μ,)}, x)
154+
z = x - d.μ
155+
logdensity(d.parent, z)
156+
end
157+
158+
function logdensity(d::Affine{(:μ,:σ)}, x)
159+
z = d.σ \ (x - d.μ)
160+
logdensity(d.parent, z)
161+
end
162+
163+
function logdensity(d::Affine{(:μ,:ω)}, x)
164+
z = d.ω * (x - d.μ)
165+
logdensity(d.parent, z)
166+
end
138167

139168
# logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
140-
@inline function logdensity(d::Affine{(:μ,:σ), Tuple{AbstractVector, AbstractMatrix}}, x)
169+
@inline function logdensity(d::Affine{(:μ,:σ), P, Tuple{V,M}}, x) where {P, V<:AbstractVector, M<:AbstractMatrix}
141170
z = x - d.μ
142171
σ = d.σ
143172
if σ isa Factorization
@@ -149,7 +178,7 @@ logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
149178
end
150179

151180
# logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
152-
@inline function logdensity(d::Affine{(:μ,:ω), Tuple{AbstractVector, AbstractMatrix}}, x)
181+
@inline function logdensity(d::Affine{(:μ,:ω), P,Tuple{V,M}}, x) where {P,V<:AbstractVector, M<:AbstractMatrix}
153182
z = x - d.μ
154183
lmul!(d.ω, z)
155184
logdensity(d.parent, z)
@@ -161,7 +190,7 @@ basemeasure(d::Affine) = affine(getfield(d, :f), basemeasure(d.parent))
161190
# example it wouldn't make sense to apply a log-Jacobian to a point measure
162191
basemeasure(d::Affine{N,L}) where {N, L<:Lebesgue} = weightedmeasure(-logjac(d), d.parent)
163192

164-
function basemeasure(d::Affine{N,L}) where {N, L<:PowerMeasure{typeof(identity), <:Lebesgue}}
193+
function basemeasure(d::Affine{N,L}) where {N, L<:PowerMeasure{typeof(identity), typeof(identity), <:Lebesgue}}
165194
weightedmeasure(-logjac(d), d.parent)
166195
end
167196

src/combinators/for.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,12 @@ julia> For(eachrow(rand(4,2))) do x Normal(x[1], x[2]) end |> marginals |> colle
7575
```
7676
7777
"""
78-
For(f, dims...) = ProductMeasure(i -> f(i...), zip(dims...))
78+
For(f, dims...) = productmeasure(i -> f(i...), zip(dims...))
7979

80-
For(f, inds::AbstractArray) = ProductMeasure(f, inds)
80+
For(f, inds::AbstractArray) = productmeasure(f, inds)
8181

82-
For(f, n::Int) = ProductMeasure(f, 1:n)
83-
For(f, dims::Int...) = ProductMeasure(i -> f(Tuple(i)...), CartesianIndices(dims))
82+
For(f, n::Int) = productmeasure(f, 1:n)
83+
For(f, dims::Int...) = productmeasure(i -> f(Tuple(i)...), CartesianIndices(dims))
8484

8585

8686
function Base.eltype(d::ProductMeasure{F,I}) where {F,I<:AbstractArray}

src/combinators/mapsto.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
struct MapsTo{X,Y}
2+
x::X
3+
y::Y
4+
end
5+
6+
export , mapsto
7+
8+
mapsto(x,y) = x y
9+
10+
(x::X,y::Y) where {X,Y} = MapsTo{X,Y}(x,y)
11+
12+
Base.first(t::MapsTo) = t.x
13+
Base.last(t::MapsTo) = t.y
14+
15+
Base.Pair(t::MapsTo) = t.x => t.y
16+
17+
Base.show(io::IO, t::MapsTo) = print(t.x, "", t.y)
18+
19+
logdensity(d, t::MapsTo) = logdensity(d, t.y)

src/combinators/power.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ using FillArrays: Fill
2929

3030
export PowerMeasure
3131

32-
const PowerMeasure{F,T,N,A} = ProductMeasure{F,Fill{T,N,A}}
32+
const PowerMeasure{F,S,T,N,A} = ProductMeasure{F,S,Fill{T,N,A}}
33+
34+
Base.:^::AbstractMeasure, ::Tuple{}) = μ
3335

3436
function Base.:^::AbstractMeasure, dims::Integer...)
3537
return μ ^ dims
@@ -40,13 +42,13 @@ function Base.:^(μ::M, dims::NTuple{N,I}) where {M<:AbstractMeasure,N,I<:Intege
4042
end
4143

4244
# Same as PowerMeasure
43-
function Base.show(io::IO, d::ProductMeasure{F,<:Fill}) where {F}
45+
function Base.show(io::IO, d::ProductMeasure{F,S,<:Fill}) where {F,S}
4446
io = IOContext(io, :compact => true)
4547
print(io, d.f(first(d.pars)), " ^ ", size(d.pars))
4648
end
4749

4850
# Same as PowerMeasure{F,T,1} where {F,T}
49-
function Base.show(io::IO, d::ProductMeasure{F,Fill{T,1,A}}) where {F,T,A}
51+
function Base.show(io::IO, d::ProductMeasure{F,S,Fill{T,1,A}}) where {F,S,T,A}
5052
io = IOContext(io, :compact => true)
5153
print(io, d.f(first(d.pars)), " ^ ", size(d.pars)[1])
5254
end
@@ -55,20 +57,24 @@ end
5557

5658

5759

58-
params(d::ProductMeasure{F,<:Fill}) where {F} = params(first(marginals(d)))
60+
params(d::ProductMeasure{F,S,<:Fill}) where {F,S} = params(first(marginals(d)))
5961

60-
params(::Type{P}) where {F,P<:ProductMeasure{F,<:Fill}} = params(D)
62+
params(::Type{P}) where {F,S,P<:ProductMeasure{F,S,<:Fill}} = params(D)
6163

6264
# basemeasure(μ::PowerMeasure) = @inbounds basemeasure(first(μ.data))^size(μ.data)
6365

6466
# Same as PowerMeasure
65-
@inline basemeasure(d::ProductMeasure{F,<:Fill}) where {F}= _basemeasure(d, (basemeasure(d.f(first(d.pars)))))
67+
@inline function basemeasure(d::ProductMeasure{F,S,<:Fill}) where {F,S}
68+
_basemeasure(d, (basemeasure(d.f(first(d.pars)))))
69+
end
6670

6771
# Same as PowerMeasure
68-
@inline _basemeasure(d::ProductMeasure{F,<:Fill}, b) where {F} = b ^ size(d.pars)
72+
@inline function _basemeasure(d::ProductMeasure{F,S,<:Fill}, b) where {F,S}
73+
b ^ size(d.pars)
74+
end
6975

7076
# Same as PowerMeasure
71-
@inline function _basemeasure(d::ProductMeasure{F,<:Fill}, b::FactoredBase) where {F}
77+
@inline function _basemeasure(d::ProductMeasure{F,S,<:Fill}, b::FactoredBase) where {F,S}
7278
n = length(d.pars)
7379
inbounds(x) = all(xj -> b.inbounds(xj), x)
7480
const= n * b.const
@@ -78,7 +84,7 @@ params(::Type{P}) where {F,P<:ProductMeasure{F,<:Fill}} = params(D)
7884
end
7985

8086
# Same as PowerMeasure
81-
@inline function logdensity(d::ProductMeasure{F,<:Fill}, x) where {F}
87+
@inline function logdensity(d::ProductMeasure{F,S,<:Fill}, x) where {F,S}
8288
d1 = d.f(first(d.pars))
8389
sum(xj -> logdensity(d1, xj), x)
8490
end

src/combinators/product.jl

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,54 @@ export ProductMeasure
22

33
using MappedArrays
44
using Base: @propagate_inbounds
5+
import Base
56
using FillArrays
67

7-
struct ProductMeasure{F,I} <: AbstractMeasure
8-
f::F
8+
abstract type AbstractProductMeasure <: AbstractMeasure end
9+
10+
11+
struct ProductMeasure{F,S,I} <: AbstractProductMeasure
12+
f::Kernel{F,S}
913
pars::I
1014
end
1115

16+
17+
# TODO: Test for equality without traversal, probably by first converting to a
18+
# canonical form
19+
function Base.:(==)(a::ProductMeasure, b::ProductMeasure)
20+
all(zip(a.pars, b.pars)) do (aᵢ, bᵢ)
21+
a.f(aᵢ) == b.f(bᵢ)
22+
end
23+
end
24+
1225
Base.size::ProductMeasure) = size(marginals(μ))
1326

1427
Base.length(m::ProductMeasure{T}) where {T} = length(marginals(μ))
1528

16-
# TODO: Pull weights outside
17-
basemeasure(d::ProductMeasure) = ProductMeasure(basemeasure d.f, d.pars)
18-
basemeasure(d::ProductMeasure{typeof(identity)}) = ProductMeasure(identity, map(basemeasure, d.pars))
19-
basemeasure(d::ProductMeasure{typeof(identity), <:FillArrays.Fill}) = ProductMeasure(identity, map(basemeasure, d.pars))
29+
basemeasure(d::ProductMeasure) = productmeasure(basekernel(d.f), d.pars)
30+
31+
# TODO: Do we need these methods?
32+
# basemeasure(d::ProductMeasure) = ProductMeasure(basemeasure ∘ d.f, d.pars)
33+
# basemeasure(d::ProductMeasure{typeof(identity)}) = ProductMeasure(identity, map(basemeasure, d.pars))
34+
# basemeasure(d::ProductMeasure{typeof(identity), <:FillArrays.Fill}) = ProductMeasure(identity, map(basemeasure, d.pars))
2035

2136
export marginals
2237

23-
function marginals(d::ProductMeasure{F,I}) where {F,I}
38+
function marginals(d::ProductMeasure{F,S,I}) where {F,S,I}
2439
_marginals(d, isiterable(I))
2540
end
2641

27-
function _marginals(d::ProductMeasure{F,I}, ::Iterable) where {F,I}
42+
function _marginals(d::ProductMeasure, ::Iterable)
2843
return (d.f(i) for i in d.pars)
2944
end
3045

31-
function _marginals(d::ProductMeasure{F,I}, ::NonIterable) where {F,I}
46+
function _marginals(d::ProductMeasure{F,S,I}, ::NonIterable) where {F,S,I}
3247
error("Type $I is not iterable. Add an `iterate` or `marginals` method to fix.")
3348
end
3449

3550
testvalue(d::ProductMeasure) = map(testvalue, marginals(d))
3651

37-
function Base.show(io::IO, μ::ProductMeasure{NamedTuple{N,T}}) where {N,T}
52+
function Base.show(io::IO, μ::ProductMeasure{F,S,NamedTuple{N,T}}) where {F,S,N,T}
3853
io = IOContext(io, :compact => true)
3954
print(io, "Product(".data, ")")
4055
end
@@ -55,21 +70,25 @@ end
5570
###############################################################################
5671
# I <: Tuple
5772

73+
struct TupleProductMeasure{T} <: AbstractProductMeasure
74+
pars::T
75+
end
76+
5877
export
59-
(μs::AbstractMeasure...) = ProductMeasure(identity, μs)
78+
(μs::AbstractMeasure...) = productmeasure(μs)
6079

61-
marginals(d::ProductMeasure{F,T}) where {F, T<:Tuple} = map(d.f, d.pars)
80+
marginals(d::TupleProductMeasure{T}) where {F, T<:Tuple} = d.pars
6281

63-
function Base.show(io::IO, μ::ProductMeasure{F,T}) where {F,T <: Tuple}
82+
function Base.show(io::IO, μ::TupleProductMeasure{T}) where {F,T <: Tuple}
6483
io = IOContext(io, :compact => true)
6584
print(io, join(string.(marginals(μ)), ""))
6685
end
6786

68-
@inline function logdensity(d::ProductMeasure{F,T}, x::Tuple) where {F,T<:Tuple}
69-
mapreduce(logdensity, +, d.f.(d.pars), x)
87+
@inline function logdensity(d::TupleProductMeasure, x::Tuple) where {T<:Tuple}
88+
mapreduce(logdensity, +, d.pars, x)
7089
end
7190

72-
function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure{F,I}) where {T,F,I<:Tuple}
91+
function Base.rand(rng::AbstractRNG, ::Type{T}, d::TupleProductMeasure) where {T}
7392
rand.(d.pars)
7493
end
7594

@@ -93,7 +112,7 @@ end
93112
###############################################################################
94113
# I <: CartesianIndices
95114

96-
function Base.show(io::IO, d::ProductMeasure{F,I}) where {F, I<:CartesianIndices}
115+
function Base.show(io::IO, d::ProductMeasure{F,S,I}) where {F, S, I<:CartesianIndices}
97116
io = IOContext(io, :compact => true)
98117
print(io, "For(")
99118
print(io, d.f, ", ")
@@ -102,7 +121,7 @@ function Base.show(io::IO, d::ProductMeasure{F,I}) where {F, I<:CartesianIndices
102121
end
103122

104123

105-
# function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure{F,I}) where {T,F,I<:CartesianIndices}
124+
# function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure{F,S,I}) where {T,F,I<:CartesianIndices}
106125

107126
# end
108127

@@ -114,12 +133,12 @@ export rand!
114133
using Random: rand!, GLOBAL_RNG, AbstractRNG
115134

116135

117-
function logdensity(d::ProductMeasure{F,I}, x) where {F, I<:Base.Generator}
136+
function logdensity(d::ProductMeasure{F,S,I}, x) where {F, S, I<:Base.Generator}
118137
sum((logdensity(dj, xj) for (dj, xj) in zip(marginals(d), x)))
119138
end
120139

121140

122-
function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure{F,I}) where {T,F,I<:Base.Generator}
141+
function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure{F,S,I}) where {T,F,S,I<:Base.Generator}
123142
mar = marginals(d)
124143
elT = typeof(rand(rng, T, first(mar)))
125144

@@ -183,8 +202,8 @@ end
183202
# μ.data
184203
# end
185204

186-
function ConstructionBase.constructorof(::Type{P}) where {F,I,P <: ProductMeasure{F,I}}
187-
p -> ProductMeasure(d.f, p)
205+
function ConstructionBase.constructorof(::Type{P}) where {F,S,I,P <: ProductMeasure{F,S,I}}
206+
p -> productmeasure(d.f, p)
188207
end
189208

190209
# function Accessors.set(d::ProductMeasure{N}, ::typeof(params), p) where {N}
@@ -201,3 +220,12 @@ end
201220
# logdensity(μ_ν_x...)
202221
# end
203222
# end
223+
224+
function kernelfactor::ProductMeasure{F,S,<:Fill}) where {F,S}
225+
k = kernel(first(marginals(μ)))
226+
(p -> k.f(p)^size(μ), k.ops)
227+
end
228+
229+
function kernelfactor::ProductMeasure{F,S,A}) where {F,S,A<:AbstractArray}
230+
(p -> set.(marginals(μ), params, p), μ.pars)
231+
end

0 commit comments

Comments
 (0)