Skip to content

Commit 13658b3

Browse files
committed
Affine and AffineTransform
1 parent ef7bd92 commit 13658b3

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

src/MeasureBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ include("primitives/lebesgue.jl")
4848
include("primitives/dirac.jl")
4949
include("primitives/trivial.jl")
5050

51+
include("combinators/affine.jl")
5152
include("combinators/weighted.jl")
5253
include("combinators/superpose.jl")
5354
include("combinators/product.jl")

src/combinators/affine.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
export Affine, AffineTransform
2+
3+
struct AffineTransform{N,T}
4+
par::NamedTuple{N,T}
5+
end
6+
7+
@inline Base.getproperty(d::AffineTransform, s::Symbol) = getfield(getfield(d, :par), s)
8+
9+
Base.propertynames(d::AffineTransform{N}) where {N} = N
10+
11+
@inline Base.inv(f::AffineTransform{(:μ,:σ)}) = AffineTransform((μ = -(f.σ \ f.μ), ω = f.σ))
12+
@inline Base.inv(f::AffineTransform{(:μ,:ω)}) = AffineTransform((μ = - f.ω * f.μ, σ = f.ω))
13+
14+
(f::AffineTransform{(:μ,:σ)})(x) = f.σ * x + f.μ
15+
16+
(f::AffineTransform{(:μ,:ω)})(x) = f.ω \ x + f.μ
17+
18+
###############################################################################
19+
20+
struct Affine{N,M,T} <: AbstractMeasure
21+
f::AffineTransform{N,T}
22+
parent::M
23+
end
24+
25+
Affine(nt::NamedTuple, μ::AbstractMeasure) = Affine(AffineTransform(nt), μ)
26+
27+
Affine(nt::NamedTuple) = μ -> Affine(nt, μ)
28+
29+
Base.propertynames(d::Affine{N}) where {N} = N (:parent,)
30+
31+
@inline function Base.getproperty(d::Affine, s::Symbol)
32+
if s === :parent
33+
return getfield(d, :parent)
34+
else
35+
return getfield(getfield(d, :f), s)
36+
end
37+
end
38+
39+
logdensity(d::Affine{(:μ,:σ)}, x) = logdensity(d.parent, d.σ \ (x - d.μ)) - log(d.σ)
40+
41+
logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ)) + log(d.ω)

0 commit comments

Comments
 (0)