Skip to content

Commit 6680e77

Browse files
authored
Affine (#20)
* update Affine methods * add LinearAlgebra * get `logjac` working properly * bugfix * add tests * Try to make Julia 1.3 happy * bump version
1 parent 98326f3 commit 6680e77

File tree

4 files changed

+51
-13
lines changed

4 files changed

+51
-13
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.4.1"
4+
version = "0.4.2"
55

66
[deps]
77
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
88
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
99
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1010
KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f"
11+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1112
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1213
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
1314
MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900"

src/MeasureBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ include("primitives/trivial.jl")
6262

6363
include("combinators/factoredbase.jl")
6464
include("combinators/weighted.jl")
65-
include("combinators/affine.jl")
6665
include("combinators/superpose.jl")
6766
include("combinators/product.jl")
6867
include("combinators/for.jl")
6968
include("combinators/power.jl")
69+
include("combinators/affine.jl")
7070
include("combinators/spikemixture.jl")
7171
include("kernel.jl")
7272
include("combinators/likelihood.jl")

src/combinators/affine.jl

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export Affine, AffineTransform
2-
3-
struct AffineTransform{N,T}
2+
using LinearAlgebra
3+
@concrete terse struct AffineTransform{N,T}
44
par::NamedTuple{N,T}
55
end
66

@@ -22,16 +22,21 @@ Base.propertynames(d::AffineTransform{N}) where {N} = N
2222
(f::AffineTransform{(:μ,:σ)})(x) = f.σ * x + f.μ
2323
(f::AffineTransform{(:μ,:ω)})(x) = f.ω \ x + f.μ
2424

25+
26+
logjac(x::AbstractMatrix) = first(logabsdet(x))
27+
28+
logjac(x::Number) = log(abs(x))
29+
2530
# TODO: `log` doesn't work for the multivariate case, we need the log absolute determinant
26-
logjac(f::AffineTransform{(:μ,:σ)}) = log(f.σ)
27-
logjac(f::AffineTransform{(:μ,:ω)}) = -log(f.ω)
28-
logjac(f::AffineTransform{(:σ,)}) = log(f.σ)
29-
logjac(f::AffineTransform{(:ω,)}) = -log(f.ω)
31+
logjac(f::AffineTransform{(:μ,:σ)}) = logjac(f.σ)
32+
logjac(f::AffineTransform{(:μ,:ω)}) = -logjac(f.ω)
33+
logjac(f::AffineTransform{(:σ,)}) = logjac(f.σ)
34+
logjac(f::AffineTransform{(:ω,)}) = -logjac(f.ω)
3035
logjac(f::AffineTransform{(:μ,)}) = 0.0
3136

3237
###############################################################################
3338

34-
struct Affine{N,M,T} <: AbstractMeasure
39+
@concrete terse struct Affine{N,M,T} <: AbstractMeasure
3540
f::AffineTransform{N,T}
3641
parent::M
3742
end
@@ -62,21 +67,40 @@ Base.propertynames(d::Affine{N}) where {N} = N ∪ (:parent,)
6267
end
6368
end
6469

65-
# Note: We could also write
66-
# logdensity(d::Affine, x) = logdensity(inv(getfield(d, :f)), x)
70+
Base.size(d) = size(d.μ)
71+
Base.size(d::Affine{(:σ,)}) = (size(d.σ, 1),)
72+
Base.size(d::Affine{(:ω,)}) = (size(d.ω, 2),)
6773

68-
logdensity(d::Affine{(:μ,:σ)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
69-
logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
7074
logdensity(d::Affine{(:σ,)}, x) = logdensity(d.parent, d.σ \ x)
7175
logdensity(d::Affine{(:ω,)}, x) = logdensity(d.parent, d.ω * x)
7276
logdensity(d::Affine{(:μ,)}, x) = logdensity(d.parent, x - d.μ)
77+
logdensity(d::Affine{(:μ,:σ)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
78+
logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
79+
80+
# logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
81+
function logdensity(d::Affine{(:μ,:σ), Tuple{AbstractVector, AbstractMatrix}}, x)
82+
z = x - d.μ
83+
ldiv!(d.σ, z)
84+
logdensity(d.parent, z)
85+
end
86+
87+
# logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
88+
function logdensity(d::Affine{(:μ,:ω), Tuple{AbstractVector, AbstractMatrix}}, x)
89+
z = x - d.μ
90+
lmul!(d.ω, z)
91+
logdensity(d.parent, z)
92+
end
7393

7494
basemeasure(d::Affine) = affine(getfield(d, :f), basemeasure(d.parent))
7595

7696
# We can't do this until we know we're working with Lebesgue measure, since for
7797
# example it wouldn't make sense to apply a log-Jacobian to a point measure
7898
basemeasure(d::Affine{N,L}) where {N, L<:Lebesgue} = weightedmeasure(-logjac(d), d.parent)
7999

100+
function basemeasure(d::Affine{N,L}) where {N, L<:PowerMeasure{typeof(identity), <:Lebesgue}}
101+
weightedmeasure(-logjac(d), d.parent)
102+
end
103+
80104
logjac(d::Affine) = logjac(getfield(d, :f))
81105

82106

test/runtests.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,19 @@ end
134134
@test Affine(par)(unif) == Affine(f, unif)
135135
@test density(Affine(f, Affine(inv(f), unif)), 0.5) == 1
136136
end
137+
138+
d = ∫exp(x -> -x^2, Lebesgue(ℝ))
139+
140+
μ = randn(3)
141+
σ = LowerTriangular(randn(3,3))
142+
ω = inv(σ)
143+
144+
x = randn(3)
145+
146+
@test logdensity(Affine((μ=μ,σ=σ), d^3), x) logdensity(Affine((μ=μ,ω=ω), d^3), x)
147+
@test logdensity(Affine((σ=σ,), d^3), x) logdensity(Affine((ω=ω,), d^3), x)
148+
@test logdensity(Affine((μ=μ,), d^3), x) logdensity(d^3, x-μ)
149+
137150
end
138151

139152
@testset "For" begin

0 commit comments

Comments
 (0)