Skip to content

Commit ddc0c3d

Browse files
theogfcscherrer
andauthored
Fix a few issues with functions on superposition (#46)
* Show that the logdensity is not defined when not using 2-tuples * Fix on basemeasure * Remove Any * WIP for constructors * Add gitignore for vscode * Add tests fors the constructor * Update src/combinators/smart-constructors.jl * Update src/combinators/superpose.jl * Update test/combinators/superpose.jl Co-authored-by: Chad Scherrer <[email protected]>
1 parent 8744525 commit ddc0c3d

File tree

4 files changed

+78
-19
lines changed

4 files changed

+78
-19
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
**/Manifest.toml
55
/docs/build/
66
coverage/
7+
.vscode/settings.json

src/combinators/smart-constructors.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,17 @@ end
3434

3535
productmeasure(mar::Fill) = powermeasure(mar.value, mar.axes)
3636

37-
function productmeasure(mar::ReadonlyMappedArray{T, N, A, Returns{M}}) where {T,N,A,M}
37+
function productmeasure(mar::ReadonlyMappedArray{T,N,A,Returns{M}}) where {T,N,A,M}
3838
return powermeasure(mar.f.value, axes(mar.data))
3939
end
4040

4141
productmeasure(mar::Base.Generator) = ProductMeasure(mar)
4242
productmeasure(mar::AbstractArray) = ProductMeasure(mar)
4343

44-
4544
# TODO: Make this static when its length is static
46-
@inline function productmeasure(mar::AbstractArray{WeightedMeasure{StaticFloat64{W}, M}}) where {W,M}
45+
@inline function productmeasure(
46+
mar::AbstractArray{WeightedMeasure{StaticFloat64{W},M}},
47+
) where {W,M}
4748
return weightedmeasure(W * length(mar), productmeasure(map(basemeasure, mar)))
4849
end
4950

@@ -54,6 +55,7 @@ productmeasure(f, param_maps, pars) = ProductMeasure(kernel(f, param_maps), pars
5455

5556
productmeasure(k::ParameterizedTransitionKernel, pars) = productmeasure(k.f, k.param_maps, pars)
5657

58+
5759
function productmeasure(f::Returns{W}, ::typeof(identity), pars) where {W<:WeightedMeasure}
5860
ℓ = _logweight(f.value)
5961
base = basemeasure(f.value)
@@ -75,19 +77,31 @@ superpose(a::AbstractArray) = SuperpositionMeasure(a)
7577
superpose(t::Tuple) = SuperpositionMeasure(t)
7678
superpose(nt::NamedTuple) = SuperpositionMeasure(nt)
7779

78-
function superpose(μ::T, ν::T) where {T}
79-
if μ==ν
80+
function superpose(μ::T, ν::T) where {T<:AbstractMeasure}
81+
if μ == ν
8082
return weightedmeasure(static(logtwo), μ)
8183
else
8284
return superpose((μ, ν))
8385
end
8486
end
8587

86-
function superpose(μ, ν)
87-
components = (μ, ν)
88-
superpose(components)
88+
function superpose(μ::AbstractMeasure, μs...)
89+
if all(==(μ), μs)
90+
return weightedmeasure(log(length(μs) + 1), μ)
91+
else
92+
return superpose((μ, μs...))
93+
end
94+
end
95+
96+
add_measures(μs::AbstractVector, νs) = push!(μs, νs...)
97+
add_measures(μs::Tuple, νs) = (μs..., νs...)
98+
99+
function superpose(μ::SuperpositionMeasure, μs...)
100+
SuperpositionMeasure(add_measures(μ.components, μs))
89101
end
90102

103+
superpose(μ::SuperpositionMeasure) = μ
104+
91105
###############################################################################
92106
# WeightedMeasure
93107

@@ -125,5 +139,5 @@ function kernel(::Type{M}; param_maps...) where {M}
125139
kernel(M, nt)
126140
end
127141

128-
129142
kernel(k::ParameterizedTransitionKernel) = k
143+

src/combinators/superpose.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
2+
using LogarithmicNumbers
3+
using LogExpFunctions
4+
15
export SuperpositionMeasure
26

37
@doc raw"""
4-
struct SuperpositionMeasure{X,NT} <: AbstractMeasure
8+
struct SuperpositionMeasure{NT} <: AbstractMeasure
59
components :: NT
610
end
711
Superposition of measures is analogous to mixture distributions, but (because
@@ -57,16 +61,14 @@ function Base.:+(μ::AbstractMeasure, ν::AbstractMeasure)
5761
superpose(μ, ν)
5862
end
5963

60-
using LogarithmicNumbers
61-
6264
oneplus(x::ULogarithmic) = exp(ULogarithmic, log1pexp(x.log))
6365

6466
@inline function density_def(s::SuperpositionMeasure{Tuple{A,B}}, x) where {A,B}
6567
(μ, ν) = s.components
6668

6769
insupport(μ, x) || return exp(ULogarithmic, logdensity_def(ν, x))
6870
insupport(ν, x) || return exp(ULogarithmic, logdensity_def(μ, x))
69-
71+
7072
α = basemeasure(μ)
7173
β = basemeasure(ν)
7274
dμ_dα = exp(ULogarithmic, logdensity_def(μ, x))
@@ -76,13 +78,24 @@ oneplus(x::ULogarithmic) = exp(ULogarithmic, log1pexp(x.log))
7678
return dμ_dα / oneplus(dβ_dα) + dν_dβ / oneplus(dα_dβ)
7779
end
7880

79-
using LogExpFunctions
81+
function density_def(s::SuperpositionMeasure, x)
82+
T = typeof(s)
83+
msg = """
84+
Not implemented: There is no method
85+
density_def(::$T, x)
86+
"""
87+
error(msg)
88+
end
8089

81-
@inline function logdensity_def(μ::T, ν::T, x::Any) where T<:(SuperpositionMeasure{Tuple{A, B}} where {A, B})
90+
@inline function logdensity_def(
91+
μ::T,
92+
ν::T,
93+
x,
94+
) where {T<:(SuperpositionMeasure{Tuple{A,B}} where {A,B})}
8295
if μ === ν
8396
return zero(return_type(logdensity_def, (μ, x)))
8497
else
85-
return logdensity_def(μ,x) - logdensity_def(ν, x)
98+
return logdensity_def(μ, x) - logdensity_def(ν, x)
8699
end
87100
end
88101

@@ -94,7 +107,11 @@ end
94107
return logaddexp(logdensity_rel(μ, β, x), logdensity_rel(ν, β, x))
95108
end
96109

97-
@inline function logdensity_def(s::SuperpositionMeasure{Tuple{A,B}}, β::SuperpositionMeasure, x) where {A,B}
110+
@inline function logdensity_def(
111+
s::SuperpositionMeasure{Tuple{A,B}},
112+
β::SuperpositionMeasure,
113+
x,
114+
) where {A,B}
98115
(μ, ν) = s.components
99116
insupport(μ, x) == true || return logdensity_rel(ν, β, x)
100117
insupport(ν, x) == true || return logdensity_rel(μ, β, x)
@@ -107,7 +124,10 @@ end
107124

108125
@inline logdensity_def(s::SuperpositionMeasure, x) = log(density_def(s, x))
109126

110-
basemeasure(μ::SuperpositionMeasure) = superpose(map(basemeasure, μ.components)...)
127+
function basemeasure(μ::SuperpositionMeasure{Tuple{A,B}}) where {A,B}
128+
superpose(map(basemeasure, μ.components)...)
129+
end
130+
basemeasure(μ::SuperpositionMeasure) = superpose(map(basemeasure, μ.components))
111131

112132
# TODO: Fix `rand` method (this one is wrong)
113133
# function Base.rand(μ::SuperpositionMeasure{X,N}) where {X,N}
@@ -118,4 +138,4 @@ basemeasure(μ::SuperpositionMeasure) = superpose(map(basemeasure, μ.components
118138
any(d.components) do c
119139
dynamic(insupport(c, x))
120140
end
121-
end
141+
end

test/combinators/superpose.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using Test
2+
3+
using MeasureBase
4+
using MeasureBase: superpose
5+
6+
@testset "superpose.jl" begin
7+
μ = Dirac(0)
8+
ν = Dirac(1)
9+
μs = μ + ν
10+
@test μs isa SuperpositionMeasure{<:Tuple{Dirac,Dirac}}
11+
@test μs == SuperpositionMeasure((μ, ν)) == superpose(μ, ν)
12+
@test density_def(μs, 0) == 1.0
13+
@test basemeasure(μs) == CountingMeasure() + CountingMeasure()
14+
15+
μs = SuperpositionMeasure([μ, ν])
16+
@test μs isa SuperpositionMeasure{<:AbstractVector{<:AbstractMeasure}}
17+
@test_throws ErrorException density_def(μs, 0)
18+
@test basemeasure(μs).components == SuperpositionMeasure([CountingMeasure(), CountingMeasure()]).components
19+
20+
μ2 = μ + μ
21+
@test μ2 isa WeightedMeasure
22+
@test μ2 == superpose(μ, μ)
23+
@test basemeasure(μ2) == μ
24+
end

0 commit comments

Comments
 (0)