Skip to content

Commit 0298d66

Browse files
authored
bugfix in weighted measures (#13)
* bugfix in weighted measures * Canonical structure for Affine/Weighted composition * bump version
1 parent f0c1aaf commit 0298d66

File tree

5 files changed

+20
-4
lines changed

5 files changed

+20
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.3.6"
4+
version = "0.3.7"
55

66
[deps]
77
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"

src/MeasureBase.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ include("primitives/lebesgue.jl")
4949
include("primitives/dirac.jl")
5050
include("primitives/trivial.jl")
5151

52-
include("combinators/affine.jl")
5352
include("combinators/weighted.jl")
53+
include("combinators/affine.jl")
5454
include("combinators/superpose.jl")
5555
include("combinators/product.jl")
5656
include("combinators/for.jl")

src/combinators/affine.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ struct AffineTransform{N,T}
44
par::NamedTuple{N,T}
55
end
66

7+
params(f::AffineTransform) = getfield(f, :par)
8+
79
@inline Base.getproperty(d::AffineTransform, s::Symbol) = getfield(getfield(d, :par), s)
810

911
Base.propertynames(d::AffineTransform{N}) where {N} = N
@@ -25,6 +27,12 @@ Base.propertynames(d::AffineTransform{N}) where {N} = N
2527
struct Affine{N,M,T} <: AbstractMeasure
2628
f::AffineTransform{N,T}
2729
parent::M
30+
31+
function Affine(f::AffineTransform, parent::WeightedMeasure)
32+
WeightedMeasure(parent.logweight, Affine(f, parent.base))
33+
end
34+
35+
Affine(f::AffineTransform{N,T}, parent::M) where {N,M,T} = new{N,M,T}(f, parent)
2836
end
2937

3038
parent(d::Affine) = getfield(d, :parent)

src/combinators/half.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ function Base.rand(rng::AbstractRNG, T::Type, μ::Half)
1717
return abs(rand(rng, T, unhalf(μ)))
1818
end
1919

20-
logdensity::Half, x) = (x < 0) * (-Inf) + (x > 0) * logdensity(unhalf(μ), x)
20+
logdensity::Half, x) = x > 0 ? logdensity(unhalf(μ), x) : -Inf

src/combinators/weighted.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ end
2222
struct WeightedMeasure{R,M} <: AbstractWeightedMeasure
2323
logweight :: R
2424
base :: M
25+
26+
function WeightedMeasure(ℓ, b::WeightedMeasure)
27+
WeightedMeasure(ℓ + b.logweight, b.base)
28+
end
29+
30+
function WeightedMeasure(ℓ::R, b::M) where {R,M}
31+
new{R,M}(ℓ, b)
32+
end
2533
end
2634

2735
function Base.show(io::IO, μ::WeightedMeasure)
@@ -43,7 +51,7 @@ end
4351

4452
function Base.:*(k::T, m::AbstractMeasure) where {T <: Number}
4553
logk = log(k)
46-
return WeightedMeasure{typeof(logk), typeof(m)}(logk,m)
54+
return WeightedMeasure(logk,m)
4755
end
4856

4957
Base.:*(m::AbstractMeasure, k::Real) = k * m

0 commit comments

Comments
 (0)