Skip to content

Commit c93d8e7

Browse files
authored
Merge pull request #13 from cscherrer/master
Add for.jl
2 parents 5ac4028 + aa785ac commit c93d8e7

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

src/for.jl

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
using Distributions
2+
import Distributions.logpdf
3+
using Base.Cartesian
4+
5+
6+
export logpdf
7+
export rand
8+
9+
export For
10+
struct For{F,T,D,X}
11+
f :: F
12+
θ :: T
13+
end
14+
15+
#########################################################
16+
# T <: NTuple{N,J} where {J <: Integer}
17+
#########################################################
18+
19+
For(f, θ::J...) where {J <: Integer} = For(f,θ)
20+
21+
function For(f::F, θ::T) where {F, N, J <: Integer, T <: NTuple{N,J}}
22+
d = f.(ones(Int, N)...)
23+
D = typeof(d)
24+
X = eltype(d)
25+
For{F, NTuple{N,J}, D, X}(f,θ)
26+
end
27+
28+
@inline function logpdf(d::For{F,T,D,X1},xs::AbstractArray{X2,N}) where {F, N, J <: Integer, T <: NTuple{N,J}, D, X1, X2 <: X1}
29+
s = 0.0
30+
@inbounds @simd for θ in CartesianIndices(d.θ)
31+
s += logpdf(d.f(Tuple(θ)...), xs[θ])
32+
end
33+
s
34+
end
35+
36+
function Base.rand(dist::For)
37+
map(CartesianIndices(dist.θ)) do I
38+
(rand dist.f)(Tuple(I)...)
39+
end
40+
end
41+
42+
#########################################################
43+
# T <: NTuple{N,J} where {J <: AbstractUnitRange}
44+
#########################################################
45+
46+
For(f, θ::J...) where {J <: AbstractUnitRange} = For(f,θ)
47+
48+
function For(f::F, θ::T) where {F, N, J <: AbstractRange, T <: NTuple{N,J}}
49+
d = f.(ones(Int, N)...)
50+
D = typeof(d)
51+
X = eltype(d)
52+
For{F, NTuple{N,J}, D, X}(f,θ)
53+
end
54+
55+
56+
@inline function logpdf(d::For{F,T,D,X1},xs::AbstractArray{X2,N}) where {F, N, J <: AbstractRange, T <: NTuple{N,J}, D, X1, X2 <: X1}
57+
s = 0.0
58+
@inbounds @simd for θ in CartesianIndices(d.θ)
59+
s += logpdf(d.f(Tuple(θ)...), xs[θ])
60+
end
61+
s
62+
end
63+
64+
65+
function Base.rand(dist::For{F,T}) where {F, N, J <: AbstractRange, T <: NTuple{N,J}}
66+
map(CartesianIndices(dist.θ)) do I
67+
(rand dist.f)(Tuple(I)...)
68+
end
69+
end
70+
71+
#########################################################
72+
# T <: Base.Generator
73+
#########################################################
74+
75+
function For(f::F, θ::T) where {F, T <: Base.Generator}
76+
d = f.f.iter[1]))
77+
D = typeof(d)
78+
X = eltype(d)
79+
For{F, T, D, X}(f,θ)
80+
end
81+
82+
83+
@inline function logpdf(d :: For{F,T}, x) where {F,T <: Base.Generator}
84+
s = 0.0
85+
for (θj, xj) in zip(d.θ, x)
86+
s += logpdf(d.f(θj), xj)
87+
end
88+
s
89+
end
90+
91+
@inline function rand(d :: For{F,T,D,X}) where {F,T <: Base.Generator, D, X}
92+
rand.(Base.Generator(d.θ.f, d.θ.iter))
93+
end
94+
95+
#########################################################
96+
97+

0 commit comments

Comments
 (0)