Skip to content

Commit 1d1ad53

Browse files
theogfcscherrer
andauthored
Some convenience functions for WeightedMeasures (#43)
* convenience functions for WeightedMeasures * Fix typo * Second typo * Add tests and use _logweight * Update test/combinators/weighted.jl Co-authored-by: Chad Scherrer <[email protected]> * Remove length Co-authored-by: Chad Scherrer <[email protected]>
1 parent c533f6d commit 1d1ad53

File tree

4 files changed

+35
-6
lines changed

4 files changed

+35
-6
lines changed

src/combinators/smart-constructors.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ function productmeasure(f::Returns{FB}, param_maps, pars) where {FB<:FactoredBas
6262
end
6363

6464
function productmeasure(f::Returns{W}, ::typeof(identity), pars) where {W<:WeightedMeasure}
65-
= f.value.logweight
66-
base = f.value.base
65+
= _logweight(f.value)
66+
base = basemeasure(f.value)
6767
newbase = productmeasure(Returns(base), identity, pars)
6868
weightedmeasure(length(pars) * ℓ, newbase)
6969
end
@@ -102,7 +102,7 @@ function weightedmeasure(ℓ::R, b::M) where {R,M}
102102
end
103103

104104
function weightedmeasure(ℓ, b::WeightedMeasure)
105-
weightedmeasure(ℓ + b.logweight, b.base)
105+
weightedmeasure(ℓ + _logweight(b), b.base)
106106
end
107107

108108
###############################################################################

src/combinators/weighted.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,27 @@ export WeightedMeasure, AbstractWeightedMeasure
99

1010
abstract type AbstractWeightedMeasure <: AbstractMeasure end
1111

12-
logweight::AbstractWeightedMeasure) = μ.logweight
13-
basemeasure::AbstractWeightedMeasure) = μ.base
12+
# By default the weight for all measure is 1
13+
_logweight(::AbstractMeasure) = 0
1414

15-
@inline function logdensity_def(d::AbstractWeightedMeasure, x)
15+
@inline function logdensity_def(d::AbstractWeightedMeasure, _)
1616
d.logweight
1717
end
1818

19+
function Base.rand(rng::AbstractRNG, ::Type{T}, μ::AbstractWeightedMeasure) where {T}
20+
rand(rng, T, basemeasure(μ))
21+
end
22+
1923
###############################################################################
2024

2125
struct WeightedMeasure{R,M} <: AbstractWeightedMeasure
2226
logweight::R
2327
base::M
2428
end
2529

30+
_logweight::WeightedMeasure) = μ.logweight
31+
basemeasure::AbstractWeightedMeasure) = μ.base
32+
2633
function Base.show(io::IO, μ::WeightedMeasure)
2734
io = IOContext(io, :compact => true)
2835
print(io, exp.logweight), " * ", μ.base)

test/combinators/weighted.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using Random: MersenneTwister
2+
using Test
3+
4+
using MeasureBase
5+
using MeasureBase: _logweight, weightedmeasure, WeightedMeasure
6+
7+
@testset "weighted" begin
8+
@test iszero(_logweight(Lebesgue(ℝ)))
9+
μ₀ = Dirac(0.0)
10+
w = 2.0
11+
μ = @inferred w * μ₀
12+
@test μ == WeightedMeasure(log(w), μ₀) == weightedmeasure(log(w), μ₀)
13+
@test μ isa WeightedMeasure
14+
@test _logweight(μ) == log(w)
15+
@test _logweight(w * μ) == 2 * log(w)
16+
@test rand(MersenneTwister(123), μ) == rand(MersenneTwister(123), μ₀)
17+
x = rand()
18+
@test logdensity_def(μ, x) == log(w)
19+
@test logdensityof(μ, x) == logdensityof(μ₀, x)
20+
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,5 @@ end
212212
# @test f(x) ≈ x^2
213213
# end
214214
end
215+
216+
include("combinators/weighted.jl")

0 commit comments

Comments
 (0)