Skip to content

Commit ee9056a

Browse files
committed
more affine stuff
1 parent 4e7a08f commit ee9056a

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

src/combinators/affine.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,24 @@ Base.propertynames(d::Affine{N}) where {N} = N ∪ (:parent,)
3636
end
3737
end
3838

39-
logdensity(d::Affine{(:μ,:σ)}, x) = logdensity(d.parent, d.σ \ (x - d.μ)) - log(d.σ)
40-
logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ)) + log(d.ω)
39+
# Note: We could also write
40+
# logdensity(d::Affine, x) = logdensity(inv(getfield(d, :f)), x)
4141

42-
logdensity(d::Affine{(:σ,)}, x) = logdensity(d.parent, d.σ \ x) - log(d.σ)
43-
logdensity(d::Affine{(:ω,)}, x) = logdensity(d.parent, d.ω * x) + log(d.ω)
42+
logdensity(d::Affine{(:μ,:σ)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
43+
logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
44+
logdensity(d::Affine{(:σ,)}, x) = logdensity(d.parent, d.σ \ x)
45+
logdensity(d::Affine{(:ω,)}, x) = logdensity(d.parent, d.ω * x)
4446
logdensity(d::Affine{(:μ,)}, x) = logdensity(d.parent, x - d.μ)
47+
48+
basemeasure(d::Affine) = Affine(d.f, basemeasure(d.parent))
49+
50+
basemeasure(d::Affine{N,L}) where {N, L<:Lebesgue} = d.parent
51+
52+
logdensity(d::Affine{N,L}, x) where {N,L<:Lebesgue} = logjac(getfield(d, :f))
53+
54+
# TODO: `log` doesn't work for the multivariate case, we need the log absolute determinant
55+
logjac(::AffineTransform{(:μ,:σ)}) = -log(d.σ)
56+
logjac(::AffineTransform{(:μ,:ω)}) = log(d.ω)
57+
logjac(::AffineTransform{(:σ,)}) = -log(d.σ)
58+
logjac(::AffineTransform{(:ω,)}) = log(d.ω)
59+
logjac(::AffineTransform{(:μ,)}) = 0.0

src/macros.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,24 @@ macro parameterized(expr)
130130
_parameterized(__module__, expr)
131131
end
132132

133+
export @affinepars
134+
135+
macro affinepars(expr)
136+
_affinepars(__module__, expr)
137+
end
138+
139+
function _affinepars(__module__, ex)
140+
ex = esc(ex)
141+
quote
142+
function Base.show(io::IO, d::Affine{N, <:$ex}) where N
143+
nt1 = getfield(d.parent, :par)
144+
nt2 = getfield(getfield(d, :f), :par)
145+
par = merge(nt1, nt2)
146+
print(io, $ex, par)
147+
end
148+
end
149+
end
150+
133151
"""
134152
@half dist([paramnames])
135153

0 commit comments

Comments
 (0)