Skip to content

Commit cb0939e

Browse files
committed
Bugfixes and tests
1 parent 06d273b commit cb0939e

File tree

2 files changed

+35
-50
lines changed

2 files changed

+35
-50
lines changed

src/combinators/product.jl

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ function logdensity(d::ProductMeasure{F,T}, x::Tuple) where {F,T<:Tuple}
7272
mapreduce(logdensity, +, d.f.(d.pars), x)
7373
end
7474

75+
function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure{F,I}) where {T,F,I<:Tuple}
76+
rand.(d.pars)
77+
end
78+
7579
###############################################################################
7680
# I <: AbstractArray
7781

@@ -100,6 +104,11 @@ function Base.show(io::IO, d::ProductMeasure{F,I}) where {F, I<:CartesianIndices
100104
print(io, ")")
101105
end
102106

107+
108+
# function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure{F,I}) where {T,F,I<:CartesianIndices}
109+
110+
# end
111+
103112
###############################################################################
104113
# I <: Base.Generator
105114

@@ -113,6 +122,15 @@ function logdensity(d::ProductMeasure{F,I}, x) where {F, I<:Base.Generator}
113122
end
114123

115124

125+
function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure{F,I}) where {T,F,I<:Base.Generator}
126+
mar = marginals(d)
127+
elT = typeof(rand(rng, T, first(mar)))
128+
129+
sz = size(mar)
130+
r = ResettableRNG(rng, rand(rng, UInt))
131+
Base.Generator(s -> rand(r, d.pars.f(s)), d.pars.iter)
132+
end
133+
116134
function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure{F,I}) where {T,F,I<:Base.Generator}
117135
mar = marginals(d)
118136
elT = typeof(rand(rng, T, first(mar)))
@@ -123,7 +141,8 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, d::ProductMeasure{F,I}) where {T
123141
end
124142

125143
@propagate_inbounds function Random.rand!(rng::AbstractRNG, d::ProductMeasure, x::AbstractArray)
126-
T = float(eltype(x))
144+
# TODO: Generalize this
145+
T = Float64
127146
for(j,m) in zip(eachindex(x), marginals(d))
128147
@inbounds x[j] = rand(rng, T, m)
129148
end
@@ -135,41 +154,6 @@ end
135154

136155

137156

138-
# @propagate_inbounds function MeasureBase.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,1}}
139-
# data = marginals(d)
140-
# @boundscheck size(data) == size(x) || throw(BoundsError)
141-
# @tullio s = logdensity(data[i], x[i])
142-
# s
143-
# end
144-
145-
# @propagate_inbounds function MeasureBase.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,2}}
146-
# data = marginals(d)
147-
# @boundscheck size(data) == size(x) || throw(BoundsError)
148-
# @tullio s = @inbounds logdensity(data[i,j], x[i,j])
149-
# s
150-
# end
151-
152-
# @propagate_inbounds function MeasureBase.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,3}}
153-
# data = marginals(d)
154-
# @boundscheck size(data) == size(x) || throw(BoundsError)
155-
# @tullio s = @inbounds logdensity(data[i,j,k], x[i,j,k])
156-
# s
157-
# end
158-
159-
# @propagate_inbounds function MeasureBase.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,4}}
160-
# data = marginals(d)
161-
# @boundscheck size(data) == size(x) || throw(BoundsError)
162-
# @tullio s = @inbounds logdensity(data[i,j,k,l], x[i,j,k,l])
163-
# s
164-
# end
165-
166-
# @propagate_inbounds function MeasureBase.logdensity(d::ProductMeasure{A}, x) where{T, A<:AbstractArray{T,5}}
167-
# data = marginals(d)
168-
# @boundscheck size(data) == size(x) || throw(BoundsError)
169-
# @tullio s = @inbounds logdensity(data[i,j,k,l,m], x[i,j,k,l,m])
170-
# s
171-
# end
172-
173157
export rand!
174158
using Random: rand!, GLOBAL_RNG, AbstractRNG
175159

test/runtests.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ test_measures = [
5050
Dirac(π)
5151
Lebesgue(ℝ)
5252
Dirac(0.0) + Lebesgue(ℝ)
53+
SpikeMixture(Lebesgue(ℝ), 2)
5354
# Normal() ⊙ Cauchy()
5455
]
5556

5657
testbroken_measures = [
57-
SpikeMixture(Lebesgue(ℝ), 2)
5858
# InverseGamma(2) # Not defined yet
5959
# MvNormal(I(3)) # Entirely broken for now
6060
CountingMeasure(Float64)
@@ -124,19 +124,20 @@ end
124124
end
125125
end
126126

127-
# @testset "For" begin
128-
# FORDISTS = [
129-
# For(1:10) do j Normal(μ=j) end
130-
# For(4,3) do μ,σ Normal(μ,σ) end
131-
# For(1:4, 1:4) do μ,σ Normal(μ,σ) end
132-
# For(eachrow(rand(4,2))) do x Normal(x[1], x[2]) end
133-
# For(rand(4), rand(4)) do μ,σ Normal(μ,σ) end
134-
# ]
135-
136-
# for d in FORDISTS
137-
# @test logdensity(d, rand(d)) isa Float64
138-
# end
139-
# end
127+
@testset "For" begin
128+
FORDISTS = [
129+
For(1:10) do j Dirac(j) end
130+
For(4,3) do i,j Dirac(i) Dirac(j) end
131+
For(1:4, 1:4) do i,j Dirac(i) Dirac(j) end
132+
For(eachrow(rand(4,2))) do x Dirac(x[1]) Dirac(x[2]) end
133+
For(rand(4), rand(4)) do i,j Dirac(i) Dirac(j) end
134+
]
135+
136+
for d in FORDISTS
137+
@info "testing $d"
138+
@test logdensity(d, rand(d)) isa Float64
139+
end
140+
end
140141

141142
# import MeasureBase.:⋅
142143
# function ⋅(μ::Normal, kernel)

0 commit comments

Comments
 (0)