@@ -5,12 +5,22 @@ const VectorOfUnivariate = Distributions.Product
5
5
function arraydist (dists:: AbstractVector{<:UnivariateDistribution} )
6
6
return product_distribution (dists)
7
7
end
8
+ function arraydist (dists:: AbstractVector{<:Normal} )
9
+ m = mapvcat (mean, dists)
10
+ v = mapvcat (var, dists)
11
+ return MvNormal (m, v)
12
+ end
13
+
8
14
function Distributions. logpdf (dist:: VectorOfUnivariate , x:: AbstractVector{<:Real} )
9
- return sum (logpdf .( dist. v, x))
15
+ return sum (map ((d, x) -> logpdf (d, x), dist. v, x))
10
16
end
11
17
function Distributions. logpdf (dist:: VectorOfUnivariate , x:: AbstractMatrix{<:Real} )
12
18
# eachcol breaks Zygote, so we need an adjoint
13
- return mapvcat ((dist, c) -> logpdf .(dist, c), dist. v, eachcol (x))
19
+ return mapvcat (dist. v, eachcol (x)) do dist, c
20
+ sum (map (c) do x
21
+ logpdf (dist, c)
22
+ end )
23
+ end
14
24
end
15
25
@adjoint function Distributions. logpdf (dist:: VectorOfUnivariate , x:: AbstractMatrix{<:Real} )
16
26
# Any other more efficient implementation breaks Zygote
32
42
function Distributions. logpdf (dist:: MatrixOfUnivariate , x:: AbstractMatrix{<:Real} )
33
43
# Broadcasting here breaks Tracker for some reason
34
44
# A Zygote adjoint is defined for mapvcat to use broadcasting
35
- return sum (logpdf .(dist. dists, x))
45
+ return sum (map (dist. dists, x) do dist, x
46
+ logpdf (dist, x)
47
+ end )
36
48
end
37
49
function Distributions. logpdf (dist:: MatrixOfUnivariate , x:: AbstractArray{<:AbstractMatrix{<:Real}} )
38
50
return mapvcat (x -> logpdf (dist, x), x)
@@ -69,7 +81,7 @@ function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Mat
69
81
return mapvcat (x -> logpdf (dist, x), x)
70
82
end
71
83
@adjoint function Distributions. logpdf (dist:: VectorOfMultivariate , x:: AbstractMatrix{<:Real} )
72
- f (dist, x) = sum (i -> logpdf (dist. dists[i], view (x, :, i)), 1 : size (x, 2 ))
84
+ f (dist, x) = sum (mapvcat ( i -> logpdf (dist. dists[i], view (x, :, i)), 1 : size (x, 2 ) ))
73
85
return pullback (f, dist, x)
74
86
end
75
87
function Distributions. rand (rng:: Random.AbstractRNG , dist:: VectorOfMultivariate )
0 commit comments