Skip to content

Commit ee201c9

Browse files
authored
Mvnormal testvalue (#207)
* fix insupport(d::MvNormal, x) * insupport(d::OrthoLebesgue, x) * update tests * bump version * MvNormal work * formatting * MvNormal updates * add tests * inline * type stability * drop dead code
1 parent 7f028e7 commit ee201c9

File tree

5 files changed

+102
-32
lines changed

5 files changed

+102
-32
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureTheory"
22
uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.16.1"
4+
version = "0.16.2"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

src/combinators/affine.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ Base.size(f::AffineTransform{(:μ, :λ)}) = size(f.λ)
4141
Base.size(f::AffineTransform{(:σ,)}) = size(f.σ)
4242
Base.size(f::AffineTransform{(:λ,)}) = size(f.λ)
4343

44+
LinearAlgebra.rank(f::AffineTransform{(:σ,)}) = rank(f.σ)
45+
LinearAlgebra.rank(f::AffineTransform{(:λ,)}) = rank(f.λ)
46+
LinearAlgebra.rank(f::AffineTransform{(:μ,:σ,)}) = rank(f.σ)
47+
LinearAlgebra.rank(f::AffineTransform{(:μ,:λ,)}) = rank(f.λ)
48+
4449
function Base.size(f::AffineTransform{(:μ,)})
4550
(n,) = size(f.μ)
4651
return (n, n)
@@ -132,9 +137,16 @@ struct OrthoLebesgue{N,T} <: PrimitiveMeasure
132137
OrthoLebesgue(nt::NamedTuple{N,T}) where {N,T} = new{N,T}(nt)
133138
end
134139

140+
function insupport(d::OrthoLebesgue, x)
141+
f = AffineTransform(d.par)
142+
finv = inverse(f)
143+
z = finv(x)
144+
f(z) x
145+
end
146+
135147
basemeasure(d::OrthoLebesgue) = d
136148

137-
logdensity_def(::OrthoLebesgue, x) = static(0)
149+
logdensity_def(::OrthoLebesgue, x) = static(0.0)
138150

139151
struct Affine{N,M,T} <: AbstractMeasure
140152
f::AffineTransform{N,T}
@@ -147,7 +159,7 @@ function Pretty.tile(d::Affine)
147159
Pretty.list_layout([pars, Pretty.tile(d.parent)]; prefix = :Affine)
148160
end
149161

150-
function testvalue(d::Affine)
162+
@inline function testvalue(d::Affine)
151163
f = getfield(d, :f)
152164
z = testvalue(parent(d))
153165
return f(z)

src/parameterized/mvnormal.jl

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,42 +8,77 @@ export MvNormal
88
# MvNormal(;kwargs...) = MvNormal(kwargs)
99

1010
@kwstruct MvNormal(μ)
11+
1112
@kwstruct MvNormal(σ)
1213
@kwstruct MvNormal(λ)
14+
1315
@kwstruct MvNormal(μ, σ)
1416
@kwstruct MvNormal(μ, λ)
1517

16-
supportdim(d::MvNormal) = supportdim(params(d))
18+
@kwstruct MvNormal(Σ)
19+
@kwstruct MvNormal(Λ)
20+
@kwstruct MvNormal(μ, Σ)
21+
@kwstruct MvNormal(μ, Λ)
1722

18-
@useproxy MvNormal
23+
as(d::MvNormal{(:μ,)}) = as(Array, length(d.μ))
1924

20-
proxy(d::MvNormal) = affine(params(d), Normal()^supportdim(d))
25+
as(d::MvNormal{(:Σ,),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Σ, 1))
26+
as(d::MvNormal{(:Λ,),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Λ, 1))
27+
as(d::MvNormal{(:μ, :Σ),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Σ, 1))
28+
as(d::MvNormal{(:μ, :Λ),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Λ, 1))
2129

22-
rand(rng::AbstractRNG, ::Type{T}, d::MvNormal) where {T} = rand(rng, T, proxy(d))
30+
function as(d::MvNormal{(:σ,),Tuple{M}}) where {M<:Triangular}
31+
σ = d.σ
32+
if @inbounds all(i -> σ[i] > 0, diagind(σ))
33+
return as(Array, size(σ, 1))
34+
else
35+
@error "Not implemented yet"
36+
end
37+
end
38+
39+
function as(d::MvNormal{(:λ,),Tuple{M}}) where {M<:Triangular}
40+
λ = d.λ
41+
if @inbounds all(i -> λ[i] > 0, diagind(λ))
42+
return as(Array, size(λ, 1))
43+
else
44+
@error "Not implemented yet"
45+
end
46+
end
2347

24-
insupport(::MvNormal, x) = true
48+
for N in setdiff(AFFINEPARS, [(,)])
49+
@eval begin
50+
function as(d::MvNormal{$N})
51+
p = proxy(d)
52+
if rank(getfield(p,:f)) == only(supportdim(d))
53+
return as(Array, supportdim(d))
54+
else
55+
@error "Not yet implemented"
56+
end
57+
end
58+
end
59+
end
2560

26-
# function MvNormal(nt::NamedTuple{(:μ,)})
27-
# dim = size(nt.μ)
28-
# affine(nt, Normal() ^ dim)
29-
# end
61+
supportdim(d::MvNormal) = supportdim(params(d))
62+
63+
supportdim(nt::NamedTuple{(:Σ,)}) = size(nt.Σ, 1)
64+
supportdim(nt::NamedTuple{(:μ,:Σ)}) = size(nt.Σ, 1)
65+
supportdim(nt::NamedTuple{(:Λ,)}) = size(nt.Λ, 1)
66+
supportdim(nt::NamedTuple{(:μ,:Λ)}) = size(nt.Λ, 1)
67+
68+
@useproxy MvNormal
3069

31-
# function MvNormal(nt::NamedTuple{(:σ,)})
32-
# dim = colsize(nt.σ)
33-
# affine(nt, Normal() ^ dim)
34-
# end
70+
for N in [(,), (,), (,), (,)]
71+
@eval basemeasure_depth(d::MvNormal{$N}) = static(2)
72+
end
3573

36-
# function MvNormal(nt::NamedTuple{(:λ,)})
37-
# dim = rowsize(nt.λ)
38-
# affine(nt, Normal() ^ dim)
39-
# end
74+
proxy(d::MvNormal) = affine(params(d), Normal()^supportdim(d))
75+
76+
rand(rng::AbstractRNG, ::Type{T}, d::MvNormal) where {T} = rand(rng, T, proxy(d))
4077

41-
# function MvNormal(nt::NamedTuple{(:μ, :σ,)})
42-
# dim = colsize(nt.σ)
43-
# affine(nt, Normal() ^ dim)
44-
# end
78+
insupport(d::MvNormal, x) = insupport(proxy(d), x)
4579

46-
# function MvNormal(nt::NamedTuple{(:μ, :λ,)})
47-
# dim = rowsize(nt.λ)
48-
# affine(nt, Normal() ^ dim)
49-
# end
80+
# Note: (C::Cholesky).L may or may not make a copy, depending on C.uplo, which is not included in the type
81+
@inline proxy(d::MvNormal{(:Σ,),Tuple{C}}) where {C<:Cholesky} = affine((σ = d.Σ.L,), Normal()^supportdim(d))
82+
@inline proxy(d::MvNormal{(:Λ,),Tuple{C}}) where {C<:Cholesky} = affine((λ = d.Λ.L,), Normal()^supportdim(d))
83+
@inline proxy(d::MvNormal{(:μ, :Σ),Tuple{C}}) where {C<:Cholesky} = affine((μ = d.μ, σ = d.Σ.L), Normal()^supportdim(d))
84+
@inline proxy(d::MvNormal{(:μ, :Λ),Tuple{C}}) where {C<:Cholesky} = affine((μ = d.μ, λ = d.Λ.L), Normal()^supportdim(d))

src/utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,19 @@ function func_string(f, types)
104104
return string(f)
105105
end
106106
end
107+
108+
function getL(C::Cholesky)
109+
Cfactors = getfield(C, :factors)
110+
Cuplo = getfield(C, :uplo)
111+
112+
LowerTriangular(Cuplo === 'L' ? Cfactors : Cfactors')
113+
end
114+
115+
function getU(C::Cholesky)
116+
Cfactors = getfield(C, :factors)
117+
Cuplo = getfield(C, :uplo)
118+
119+
UpperTriangular(Cuplo === 'U' ? Cfactors : Cfactors')
120+
end
121+
122+
const Triangular = Union{L,U} where {L<:LowerTriangular,U<:UpperTriangular}

test/runtests.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ function draw2(μ)
2323
return (x, y)
2424
end
2525

26-
function test_measure(μ)
27-
logdensity_def(μ, testvalue(μ)) isa AbstractFloat
28-
end
26+
x = randn(10,3)
27+
Σ = cholesky(x'*x)
28+
Λ = cholesky(inv(Σ))
29+
σ = MeasureTheory.getL(Σ)
30+
λ = MeasureTheory.getL(Λ)
2931

3032
test_measures = Any[
3133
# Chain(x -> Normal(μ=x), Normal(μ=0.0))
@@ -52,6 +54,12 @@ test_measures = Any[
5254
Normal(2, 3)
5355
Poisson(3.1)
5456
StudentT= 2.1)
57+
MvNormal= [1 0; 0 1; 1 1])
58+
MvNormal= [1 0 1; 0 1 1])
59+
MvNormal= Σ)
60+
MvNormal= Λ)
61+
MvNormal= σ)
62+
MvNormal= λ)
5563
Uniform()
5664
Counting(Float64)
5765
Dirac(0.0) + Normal()
@@ -67,7 +75,6 @@ testbroken_measures = Any[
6775
@testset "testvalue" begin
6876
for μ in test_measures
6977
@info "testing "
70-
@test test_measure(μ)
7178
test_interface(μ)
7279
end
7380

0 commit comments

Comments
 (0)