Skip to content

Commit de50b0a

Browse files
authored
Update MvNormal (#160)
* update MvNormal to use Affine * typos * inline _logpdf * make things easily dubuggable * update tests * AND ANOTHER * update tests, but MANY NEED FIXING * LogExpFunctions 0.3.3 (for xlog1py) * working on tests * logpdf from MeasureBase * logpdf test * Symbolic-friendly StudentT * drop redudant code * basekernel * bugfix * Returns{T} <: Function * fixing stuff * asparams(::Affine) * debugging, and tests * use Measurebase.Returns (usually from Base) instead of making another * adjust for Productmesaure updates * Some refactoring * MvNormal updates * compat `replace` * oops typo * Actually, let's just make this `replace` stuff more direct * bump version
1 parent 0890e24 commit de50b0a

File tree

13 files changed

+227
-117
lines changed

13 files changed

+227
-117
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureTheory"
22
uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.12.0"
4+
version = "0.13.0"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -37,7 +37,7 @@ DynamicIterators = "0.4"
3737
FillArrays = "0.12"
3838
InfiniteArrays = "0.11"
3939
KeywordCalls = "0.2"
40-
LogExpFunctions = "0.3"
40+
LogExpFunctions = "0.3.3"
4141
MLStyle = "0.4"
4242
MacroTools = "0.5"
4343
MappedArrays = "0.4"

src/MeasureTheory.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ export CountingMeasure
2828
export TrivialMeasure
2929
export Likelihood
3030
export testvalue
31+
export basekernel
3132

3233
using InfiniteArrays
3334
using ConcreteStructs
@@ -39,7 +40,12 @@ using StatsFuns
3940
using SpecialFunctions
4041
using LogExpFunctions
4142

43+
import NamedTupleTools
44+
4245
import MeasureBase: testvalue, logdensity, density, basemeasure, kernel, params, ∫
46+
import MeasureBase: affine, supportdim
47+
48+
import Base: rand
4349

4450
using Reexport
4551
@reexport using MeasureBase
@@ -53,38 +59,22 @@ export as
5359
export Affine
5460
export AffineTransform
5561

56-
if VERSION < v"1.7.0-beta1.0"
57-
@eval begin
58-
struct Returns{T}
59-
value::T
60-
end
61-
62-
(f::Returns)(x) = f.value
63-
end
64-
end
62+
using MeasureBase: Returns
6563

6664
sampletype::AbstractMeasure) = typeof(testvalue(μ))
6765

6866
# sampletype(μ::AbstractMeasure) = sampletype(basemeasure(μ))
6967

70-
import Distributions: pdf, logpdf
7168

7269

73-
export pdf, logpdf
74-
75-
function logpdf(d::AbstractMeasure, x)
76-
_logpdf(d, x, logdensity(d,x))
77-
end
7870

79-
function _logpdf(d::AbstractMeasure, x, acc::Real)
80-
β = basemeasure(d)
81-
d === β && return acc
71+
import Distributions: logpdf, pdf
8272

83-
_logpdf(β, x, acc + logdensity(β, x))
84-
end
73+
export pdf, logpdf
8574

75+
Distributions.logpdf(d::AbstractMeasure, x) = MeasureBase.logpdf(d, x)
8676

87-
pdf(d::AbstractMeasure, x) = exp(logpdf(d, x))
77+
Distributions.pdf(d::AbstractMeasure, x) = exp(Dists.logpdf(d, x))
8878

8979
"""
9080
logdensity(μ::AbstractMeasure [, ν::AbstractMeasure], x::X)
@@ -95,16 +85,26 @@ is understood to be `basemeasure(μ)`.
9585
"""
9686
function logdensity end
9787

88+
89+
const AFFINEPARS = [
90+
(,)
91+
(,)
92+
(,)
93+
(,)
94+
(,)
95+
]
96+
97+
9898
include("const.jl")
9999
# include("traits.jl")
100100
include("parameterized.jl")
101101
# include("resettablerng.jl")
102102

103+
include("combinators/affine.jl")
103104
include("combinators/weighted.jl")
104105
include("combinators/product.jl")
105106
include("combinators/transforms.jl")
106107
include("combinators/chain.jl")
107-
# include("combinators/basemeasure.jl")
108108

109109
include("distributions.jl")
110110

src/combinators/affine.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
asparams(::Affine, ::Val{:μ}) = asℝ
2+
asparams(::Affine, ::Val{:σ}) = asℝ₊
3+
asparams(::Type{A}, ::Val{:μ}) where {A<:Affine} = asℝ
4+
asparams(::Type{A}, ::Val{:σ}) where {A<:Affine} = asℝ₊
5+
6+
asparams(::Affine{N,M,T}, ::Val{:μ}) where {N,M,T<:AbstractArray} = as(Array, asℝ, size(d.μ))
7+
asparams(::Affine{N,M,T}, ::Val{:σ}) where {N,M,T<:AbstractArray} = as(Array, asℝ, size(d.σ))

src/combinators/product.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
1-
2-
3-
4-
function TV.as(d::ProductMeasure{F,A}) where {F,A<:AbstractArray}
1+
function TV.as(d::ProductMeasure{F,S,A}) where {F,S,A<:AbstractArray}
52
d1 = marginals(d).f(first(marginals(d).data))
63
as(Array, as(d1), size(marginals(d))...)
74
end
85

96
###############################################################################
107
# I <: Base.Generator
118

12-
function TV.as(d::ProductMeasure{F,I}) where {F, I<:Base.Generator}
9+
function TV.as(d::ProductMeasure{F,S,I}) where {F,S,I<:Base.Generator}
1310
d1 = marginals(d).f(first(marginals(d).iter))
1411
as(Array, as(d1), size(marginals(d))...)
1512
end
1613

17-
14+
function TV.as(d::ProductMeasure{Returns{T},F,A}) where {T, F, A <: AbstractArray}
15+
as(Array, as(d.f.f.value), size(d.pars))
16+
end
1817

1918
function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure, d1::Dists.Distribution) where {T}
2019
mar = marginals(d)
@@ -32,11 +31,11 @@ end
3231

3332

3433
# e.g. set(Normal(μ=2)^5, params, randn(5))
35-
function Accessors.set(d::ProductMeasure{F,A}, ::typeof(params), p::AbstractArray) where {F,A<:AbstractArray}
34+
function Accessors.set(d::ProductMeasure{F,S,A}, ::typeof(params), p::AbstractArray) where {F,S,A<:AbstractArray}
3635
set.(marginals(d), params, p)
3736
end
3837

39-
function Accessors.set(d::ProductMeasure{F,A}, ::typeof(params), p) where {F,A<:AbstractArray}
38+
function Accessors.set(d::ProductMeasure{F,S,A}, ::typeof(params), p) where {F,S,A<:AbstractArray}
4039
par = typeof(d.pars[1])(p)
4140
ProductMeasure(d.f, Fill(par, size(d.pars)))
4241
end

src/distproxy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ for m in keys(PROXIES)
2626
end
2727

2828

29-
Base.rand(rng::AbstractRNG, T::Type, d::ParameterizedMeasure) = rand(rng, distproxy(d))
29+
Base.rand(rng::AbstractRNG, ::Type{T}, d::ParameterizedMeasure) where {T} = rand(rng, distproxy(d))
3030

3131
# MonteCarloMeasurements.Particles(N::Int, d::AbstractMeasure) = MonteCarloMeasurements.Particles(N, distproxy(d))
3232

src/parameterized.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,3 @@ asparams(μ::ParameterizedMeasure, nt::NamedTuple=NamedTuple()) = asparams(const
7474

7575
TV.as(::Half) = asℝ₊
7676

77-
asparams(::Affine, ::Val{:μ}) = asℝ
78-
asparams(::Affine, ::Val{:σ}) = asℝ₊
79-
asparams(::Type{A}, ::Val{:μ}) where {A<:Affine} = asℝ
80-
asparams(::Type{A}, ::Val{:σ}) where {A<:Affine} = asℝ₊

src/parameterized/cauchy.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,23 @@ export Cauchy, HalfCauchy
66
@parameterized Cauchy(μ,σ) (1/π) * Lebesgue(ℝ)
77

88
@kwstruct Cauchy()
9+
@kwstruct Cauchy(μ)
10+
@kwstruct Cauchy(σ)
11+
@kwstruct Cauchy(μ,σ)
12+
@kwstruct Cauchy(ω)
13+
@kwstruct Cauchy(μ,ω)
914

10-
Cauchy(nt::NamedTuple{(:μ,:σ)}) = Affine(nt, Cauchy())
11-
Cauchy(nt::NamedTuple{(:μ,:ω)}) = Affine(nt, Cauchy())
12-
Cauchy(nt::NamedTuple{(:σ,)}) = Affine(nt, Cauchy())
13-
Cauchy(nt::NamedTuple{(:ω,)}) = Affine(nt, Cauchy())
14-
Cauchy(nt::NamedTuple{(:μ,)}) = Affine(nt, Cauchy())
1515

16-
@affinepars Cauchy
16+
17+
for N in AFFINEPARS
18+
@eval begin
19+
proxy(d::Cauchy{$N}) = affine(params(d), Cauchy())
20+
logdensity(d::Cauchy{$N}, x) = logdensity(proxy(d), x)
21+
basemeasure(d::Cauchy{$N}) = basemeasure(proxy(d))
22+
end
23+
end
24+
25+
# @affinepars Cauchy
1726

1827
function logdensity(d::Cauchy{()} , x)
1928
return -log1p(x^2)
@@ -34,3 +43,8 @@ TV.as(::Cauchy) = asℝ
3443
HalfCauchy(σ) = HalfCauchy=σ)
3544

3645
distproxy(d::Cauchy{()}) = Dists.Cauchy()
46+
distproxy(d::Cauchy{(:μ,)}) = Dists.Cauchy(d.μ, 1.0)
47+
distproxy(d::Cauchy{(:σ,)}) = Dists.Cauchy(0.0, d.σ)
48+
distproxy(d::Cauchy{(:μ,:σ)}) = Dists.Cauchy(d.μ, d.σ)
49+
distproxy(d::Cauchy{(:ω,)}) = Dists.Cauchy(0.0, inv(d.ω))
50+
distproxy(d::Cauchy{(:μ,:ω)}) = Dists.Cauchy(d.μ, inv(d.ω))

src/parameterized/gumbel.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,25 @@
22

33
export Gumbel
44

5-
@parameterized Gumbel() Lebesgue(ℝ)
5+
@parameterized Gumbel()
6+
7+
basemeasure(::Gumbel{()}) = Lebesgue(ℝ)
68

79
@kwstruct Gumbel()
810

9-
Gumbel(nt::NamedTuple{(:σ,)}) = Affine(nt, Gumbel())
11+
@kwstruct Gumbel(β)
1012

11-
@affinepars Gumbel
13+
@kwstruct Gumbel(μ,β)
1214

15+
# map affine names to those more common for Gumbel
16+
for N in [(,), (,), (,)]
17+
G = tuple(replace(collect(N), => )...)
18+
@eval begin
19+
proxy(d::Gumbel{$G}) = affine(NamedTuple{$N}(values(params(d))), Gumbel())
20+
logdensity(d::Gumbel{$G}, x) = logdensity(proxy(d), x)
21+
basemeasure(d::Gumbel{$G}) = basemeasure(proxy(d))
22+
end
23+
end
1324

1425
function logdensity(d::Gumbel{()} , x)
1526
return -exp(-x) - x

src/parameterized/laplace.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@ export Laplace
55

66
@parameterized Laplace() (1/2) * Lebesgue(ℝ)
77

8-
Laplace(nt::NamedTuple{(:μ,:σ)}) = Affine(nt, Laplace())
9-
Laplace(nt::NamedTuple{(:μ,:ω)}) = Affine(nt, Laplace())
10-
Laplace(nt::NamedTuple{(:σ,)}) = Affine(nt, Laplace())
11-
Laplace(nt::NamedTuple{(:ω,)}) = Affine(nt, Laplace())
12-
Laplace(nt::NamedTuple{(:μ,)}) = Affine(nt, Laplace())
8+
for N in AFFINEPARS
9+
@eval begin
10+
proxy(d::Laplace{$N}) = affine(params(d), Laplace())
11+
logdensity(d::Laplace{$N}, x) = logdensity(proxy(d), x)
12+
basemeasure(d::Laplace{$N}) = basemeasure(proxy(d))
13+
end
14+
end
1315

14-
@affinepars Laplace
16+
# @affinepars Laplace
1517

1618

1719
function logdensity(d::Laplace{()} , x)
@@ -25,3 +27,8 @@ Base.rand(rng::AbstractRNG, μ::Laplace{()}) = rand(rng, Dists.Laplace())
2527
TV.as(::Laplace) = asℝ
2628

2729
distproxy(::Laplace{()}) = Dists.Laplace()
30+
distproxy(d::Laplace{(:μ,)}) = Dists.Laplace(d.μ, 1.0)
31+
distproxy(d::Laplace{(:σ,)}) = Dists.Laplace(0.0, d.σ)
32+
distproxy(d::Laplace{(:μ,:σ)}) = Dists.Laplace(d.μ, d.σ)
33+
distproxy(d::Laplace{(:ω,)}) = Dists.Laplace(0.0, inv(d.ω))
34+
distproxy(d::Laplace{(:μ,:ω)}) = Dists.Laplace(d.μ, inv(d.ω))

src/parameterized/mvnormal.jl

Lines changed: 31 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,47 @@
11

22
# Multivariate Normal distribution
33

4-
using LinearAlgebra
54
export MvNormal
6-
using Random
7-
import Base
85

6+
@parameterized MvNormal(μ,σ)
97

10-
struct MvNormal{N, T, I, J} <: ParameterizedMeasure{N}
11-
par::NamedTuple{N, T}
12-
end
8+
# MvNormal(;kwargs...) = MvNormal(kwargs)
139

10+
@kwstruct MvNormal(μ)
11+
@kwstruct MvNormal(σ)
12+
@kwstruct MvNormal(ω)
13+
@kwstruct MvNormal(μ,σ)
14+
@kwstruct MvNormal(μ,ω)
1415

15-
function MvNormal(nt::NamedTuple{N,T}) where {N,T}
16-
I,J = mvnormaldims(nt)
16+
supportdim(d::MvNormal) = supportdim(params(d))
1717

18-
cache = Vector{Float64}(undef, max(I,J))
19-
MvNormal{N,T,I,J}(cache, nt)
20-
end
18+
proxy(d::MvNormal) = affine(params(d), Normal() ^ supportdim(d))
19+
logdensity(d::MvNormal, x) = logdensity(proxy(d), x)
20+
basemeasure(d::MvNormal) = basemeasure(proxy(d))
2121

22-
function Base.size(d::MvNormal{N, T, I, J}) where {N,T,I,J}
23-
return (I,)
24-
end
22+
rand(rng::AbstractRNG, ::Type{T}, d::MvNormal) where {T} = rand(rng, T, proxy(d))
2523

26-
mvnormaldims(nt::NamedTuple{(:A, :b)}) = size(nt.A)
27-
28-
function MeasureTheory.basemeasure::MvNormal{N, T, I,I}) where {N, T, I}
29-
return (1 / sqrt2π)^I * Lebesgue(ℝ)^I
30-
end
31-
32-
sampletype(d::MvNormal{N, T, I, J}) where {N,T,I,J} = Vector{Float64}
33-
34-
MvNormal(; kwargs...) = begin
35-
MvNormal((; kwargs...))
36-
end
37-
38-
39-
40-
function Random.rand!(rng::AbstractRNG, d::MvNormal{(:A, :b),T,I,J}, x::AbstractArray) where {T,I,J}
41-
z = getcache(d, J)
42-
rand!(rng, Normal()^J, z)
43-
copyto!(x, d.b)
44-
mul!(x, d.A, z, 1.0, 1.0)
45-
return x
46-
end
24+
# function MvNormal(nt::NamedTuple{(:μ,)})
25+
# dim = size(nt.μ)
26+
# affine(nt, Normal() ^ dim)
27+
# end
4728

48-
function logdensity(d::MvNormal{(:A,:b)}, x)
49-
cache = getcache(d, size(d))
50-
z = d.A \ (x - d.b)
51-
return logdensity(MvNormal(), z) - logabsdet(d.A)
52-
end
29+
# function MvNormal(nt::NamedTuple{(:σ,)})
30+
# dim = colsize(nt.σ)
31+
# affine(nt, Normal() ^ dim)
32+
# end
5333

54-
(::MvNormal, ::Lebesgue{ℝ}) = true
34+
# function MvNormal(nt::NamedTuple{(:ω,)})
35+
# dim = rowsize(nt.ω)
36+
# affine(nt, Normal() ^ dim)
37+
# end
5538

56-
# function logdensity(d::MvNormal{(:Σ⁻¹,)}, x)
57-
# @tullio ℓ = -0.5 * x[i] * d.Σ⁻¹[i,j] * x[j]
58-
# return ℓ
39+
# function MvNormal(nt::NamedTuple{(:μ, :σ,)})
40+
# dim = colsize(nt.σ)
41+
# affine(nt, Normal() ^ dim)
5942
# end
6043

61-
mvnormaldims(nt::NamedTuple{(:Σ⁻¹,)}) = size(nt.Σ⁻¹)
44+
# function MvNormal(nt::NamedTuple{(:μ, :ω,)})
45+
# dim = rowsize(nt.ω)
46+
# affine(nt, Normal() ^ dim)
47+
# end

0 commit comments

Comments
 (0)