Skip to content

Commit d22485d

Browse files
authored
logjac for non-square matrices (#21)
* logjac for non-square matrices * `rand!` for Affine measures * testvalue(::Affine) * check for factorization * drop redundant method * `apply!` method for Factorization * bugfix * drop `rand` for now * roll back `rand` for now * logdensity for product measures * bugfix * logpdf * inlining * add f to propertynames * drop unneeded method * affine test * tests * typo * version bump
1 parent 6680e77 commit d22485d

File tree

12 files changed

+137
-23
lines changed

12 files changed

+137
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.4.2"
4+
version = "0.4.3"
55

66
[deps]
77
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"

src/MeasureBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module MeasureBase
33
const logtwo = log(2.0)
44

55
using Random
6+
import Random: rand!
67

78
using ConcreteStructs
89
using MLStyle

src/combinators/affine.jl

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,65 @@ Base.propertynames(d::AffineTransform{N}) where {N} = N
1616
@inline Base.inv(f::AffineTransform{(:ω,)}) = AffineTransform((σ = f.ω,))
1717
@inline Base.inv(f::AffineTransform{(:μ,)}) = AffineTransform((μ = -f.μ,))
1818

19+
# `size(f) == (m,n)` means `f : ℝⁿ → ℝᵐ`
20+
Base.size(f::AffineTransform{(:μ,:σ)}) = size(f.σ)
21+
Base.size(f::AffineTransform{(:μ,:ω)}) = size(f.ω)
22+
Base.size(f::AffineTransform{(:σ,)}) = size(f.σ)
23+
Base.size(f::AffineTransform{(:ω,)}) = size(f.ω)
24+
25+
function Base.size(f::AffineTransform{(:μ,)})
26+
(n,) = size(f.μ)
27+
return (n,n)
28+
end
29+
30+
Base.size(f::AffineTransform, n::Int) = @inbounds size(f)[n]
31+
1932
(f::AffineTransform{(:μ,)})(x) = x + f.μ
2033
(f::AffineTransform{(:σ,)})(x) = f.σ * x
2134
(f::AffineTransform{(:ω,)})(x) = f.ω \ x
2235
(f::AffineTransform{(:μ,:σ)})(x) = f.σ * x + f.μ
2336
(f::AffineTransform{(:μ,:ω)})(x) = f.ω \ x + f.μ
2437

38+
@inline function apply!(x, f::AffineTransform{(:μ,)}, z)
39+
x .= z .+ f.μ
40+
return x
41+
end
42+
43+
@inline function apply!(x, f::AffineTransform{(:σ,)}, z)
44+
mul!(x, f.σ, z)
45+
return x
46+
end
47+
48+
@inline function apply!(x, f::AffineTransform{(:ω,), Tuple{F}}, z) where {F<:Factorization}
49+
ldiv!(x, f.ω, z)
50+
return x
51+
end
52+
53+
@inline function apply!(x, f::AffineTransform{(:ω,)}, z)
54+
ldiv!(x, factorize(f.ω), z)
55+
return x
56+
end
57+
58+
@inline function apply!(x, f::AffineTransform{(:μ,:σ)}, z)
59+
apply!(x, AffineTransform((σ = f.σ,)))
60+
apply!(x, AffineTransform((μ = f.μ,)))
61+
return x
62+
end
63+
64+
@inline function apply!(x, f::AffineTransform{(:μ,:ω)}, z)
65+
apply!(x, AffineTransform((ω = f.ω,)))
66+
apply!(x, AffineTransform((μ = f.μ,)))
67+
return x
68+
end
69+
70+
function logjac(x::AbstractMatrix)
71+
(m,n) = size(x)
72+
m == n && return first(logabsdet(x))
2573

26-
logjac(x::AbstractMatrix) = first(logabsdet(x))
74+
# Equivalent to sum(log, svdvals(x)), but much faster
75+
m > n && return first(logabsdet(x' * x)) / 2
76+
return first(logabsdet(x * x')) / 2
77+
end
2778

2879
logjac(x::Number) = log(abs(x))
2980

@@ -41,6 +92,12 @@ logjac(f::AffineTransform{(:μ,)}) = 0.0
4192
parent::M
4293
end
4394

95+
function testvalue(d::Affine)
96+
f = getfield(d, :f)
97+
z = testvalue(parent(d))
98+
return f(z)
99+
end
100+
44101
Affine(nt::NamedTuple, μ::AbstractMeasure) = affine(nt, μ)
45102

46103
Affine(nt::NamedTuple) = affine(nt)
@@ -57,17 +114,19 @@ function paramnames(::Type{A}) where {N,M, A<:Affine{N,M}}
57114
tuple(union(N, paramnames(M))...)
58115
end
59116

60-
Base.propertynames(d::Affine{N}) where {N} = N (:parent,)
117+
Base.propertynames(d::Affine{N}) where {N} = N (:parent,:f)
61118

62119
@inline function Base.getproperty(d::Affine, s::Symbol)
63120
if s === :parent
64121
return getfield(d, :parent)
122+
elseif s === :f
123+
return getfield(d, :f)
65124
else
66125
return getproperty(getfield(d, :f), s)
67126
end
68127
end
69128

70-
Base.size(d) = size(d.μ)
129+
Base.size(d::Affine) = size(d.μ)
71130
Base.size(d::Affine{(:σ,)}) = (size(d.σ, 1),)
72131
Base.size(d::Affine{(:ω,)}) = (size(d.ω, 2),)
73132

@@ -78,14 +137,19 @@ logdensity(d::Affine{(:μ,:σ)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
78137
logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
79138

80139
# logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
81-
function logdensity(d::Affine{(:μ,:σ), Tuple{AbstractVector, AbstractMatrix}}, x)
140+
@inline function logdensity(d::Affine{(:μ,:σ), Tuple{AbstractVector, AbstractMatrix}}, x)
82141
z = x - d.μ
83-
ldiv!(d.σ, z)
142+
σ = d.σ
143+
if σ isa Factorization
144+
ldiv!(σ, z)
145+
else
146+
ldiv!(factorize(σ), z)
147+
end
84148
logdensity(d.parent, z)
85149
end
86150

87151
# logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
88-
function logdensity(d::Affine{(:μ,:ω), Tuple{AbstractVector, AbstractMatrix}}, x)
152+
@inline function logdensity(d::Affine{(:μ,:ω), Tuple{AbstractVector, AbstractMatrix}}, x)
89153
z = x - d.μ
90154
lmul!(d.ω, z)
91155
logdensity(d.parent, z)
@@ -103,9 +167,17 @@ end
103167

104168
logjac(d::Affine) = logjac(getfield(d, :f))
105169

106-
107-
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, d::Affine) where {T}
108-
z = rand(rng, T, parent(d))
170+
function Random.rand!(rng::Random.AbstractRNG, d::Affine, x::AbstractVector{T}, z=Vector{T}(undef, size(getfield(d,:f),2))) where {T}
171+
rand!(rng, parent(d), z)
109172
f = getfield(d, :f)
110-
return f(z)
173+
apply!(x, f, z)
174+
return x
111175
end
176+
177+
178+
# function Base.rand(rng::Random.AbstractRNG, ::Type{T}, d::Affine) where {T}
179+
# f = getfield(d, :f)
180+
# z = rand(rng, T, parent(d))
181+
# apply!(x, f, z)
182+
# return z
183+
# end

src/combinators/factoredbase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ struct FactoredBase{R,C,V,B} <: AbstractMeasure
77
base :: B
88
end
99

10-
function logdensity(d::FactoredBase, x)
10+
@inline function logdensity(d::FactoredBase, x)
1111
d.inbounds(x) || return -Inf
1212
d.const+ d.varℓ()
1313
end

src/combinators/pointwise.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Base.length(m::PointwiseProductMeasure{T}) where {T} = length(m.data)
2727

2828
(args...) = pointwiseproduct(args...)
2929

30-
function logdensity(d::PointwiseProductMeasure, x)
30+
@inline function logdensity(d::PointwiseProductMeasure, x)
3131
sum((logdensity(dⱼ, x) for dⱼ in d.data))
3232
end
3333

src/combinators/power.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ params(::Type{P}) where {F,P<:ProductMeasure{F,<:Fill}} = params(D)
7878
end
7979

8080
# Same as PowerMeasure
81-
function logdensity(d::ProductMeasure{F,<:Fill}, x) where {F}
81+
@inline function logdensity(d::ProductMeasure{F,<:Fill}, x) where {F}
8282
d1 = d.f(first(d.pars))
8383
sum(xj -> logdensity(d1, xj), x)
8484
end

src/combinators/product.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function Base.show(io::IO, μ::ProductMeasure{F,T}) where {F,T <: Tuple}
6262
print(io, join(string.(marginals(μ)), ""))
6363
end
6464

65-
function logdensity(d::ProductMeasure{F,T}, x::Tuple) where {F,T<:Tuple}
65+
@inline function logdensity(d::ProductMeasure{F,T}, x::Tuple) where {F,T<:Tuple}
6666
mapreduce(logdensity, +, d.f.(d.pars), x)
6767
end
6868

@@ -192,3 +192,9 @@ end
192192
# function Accessors.set(d::ProductMeasure{F,T}, ::typeof(params), p::Tuple) where {F, T<:Tuple}
193193
# set.(marginals(d), params, p)
194194
# end
195+
196+
# function logdensity(μ::ProductMeasure, ν::ProductMeasure, x)
197+
# sum(zip(marginals(μ), marginals(ν), x)) do μ_ν_x
198+
# logdensity(μ_ν_x...)
199+
# end
200+
# end

src/combinators/restricted.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ struct RestrictedMeasure{F,M} <: AbstractMeasure
33
base::M
44
end
55

6-
function logdensity(d::RestrictedMeasure, x)
6+
@inline function logdensity(d::RestrictedMeasure, x)
77
d.f(x) || return -Inf
8+
return 0.0
89
end
910

1011
function density(d::RestrictedMeasure, x)
1112
d.f(x) || return 0.0
13+
return 1.0
1214
end
1315

1416
basemeasure::RestrictedMeasure) = μ.base

src/density.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,12 @@ Define a new measure in terms of a density `f` over some measure `base`.
9898

9999
# TODO: `density` and `logdensity` functions for `DensityMeasure`
100100

101-
function logdensity::T, ν::T, x) where {T<:AbstractMeasure}
101+
@inline function logdensity::T, ν::T, x) where {T<:AbstractMeasure}
102102
μ==ν && return 0.0
103103
invoke(logdensity, Tuple{AbstractMeasure, AbstractMeasure, typeof(x)}, μ, ν, x)
104104
end
105105

106-
function logdensity::AbstractMeasure, ν::AbstractMeasure, x)
106+
@inline function logdensity::AbstractMeasure, ν::AbstractMeasure, x)
107107
α = basemeasure(μ)
108108
β = basemeasure(ν)
109109

@@ -135,6 +135,17 @@ function logdensity(μ::AbstractMeasure, ν::AbstractMeasure, x)
135135
return
136136
end
137137

138+
function logpdf(d::AbstractMeasure, x)
139+
_logpdf(d, basemeasure(d), x, zero(Float64))
140+
end
141+
142+
@inline function _logpdf(d::AbstractMeasure, β::AbstractMeasure, x, ℓ::Float64)
143+
d === β && return
144+
Δℓ = logdensity(d, x)
145+
# @show Δℓ, d
146+
_logpdf(β, basemeasure(β), x, ℓ + Δℓ)
147+
end
148+
138149
logdensity(::Lebesgue, ::Lebesgue, x) = 0.0
139150

140151
# logdensity(::Lebesgue{ℝ}, ::Lebesgue{ℝ}, x) = zero(x)

src/rand.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,19 @@ Base.rand(T::Type, μ::AbstractMeasure) = rand(Random.GLOBAL_RNG, T, μ)
66

77
Base.rand(rng::AbstractRNG, d::AbstractMeasure) = rand(rng, Float64, d)
88

9-
@inline Random.rand!(d::AbstractMeasure, arr::AbstractArray) = rand!(GLOBAL_RNG, d, arr)
10-
11-
9+
@inline Random.rand!(d::AbstractMeasure, args...) = rand!(GLOBAL_RNG, d, args...)
10+
11+
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, d::Affine) where {T}
12+
z = rand(rng, T, parent(d))
13+
f = getfield(d, :f)
14+
return f(z)
15+
end
16+
17+
# TODO: Make this work
18+
# function Base.rand(rng::AbstractRNG, ::Type{T}, d::AbstractMeasure) where {T}
19+
# x = testvalue(d)
20+
# rand!(d, x)
21+
# end
1222

1323
# struct ArraySlot{A,I}
1424
# arr::A

0 commit comments

Comments
 (0)