Skip to content

Commit c4adf67

Browse files
committed
test fixes
1 parent 9c02c84 commit c4adf67

File tree

4 files changed

+26
-19
lines changed

4 files changed

+26
-19
lines changed

src/arraydist.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,22 @@ const VectorOfUnivariate = Distributions.Product
55
function arraydist(dists::AbstractVector{<:UnivariateDistribution})
66
return product_distribution(dists)
77
end
8+
function arraydist(dists::AbstractVector{<:Normal})
9+
m = mapvcat(mean, dists)
10+
v = mapvcat(var, dists)
11+
return MvNormal(m, v)
12+
end
13+
814
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
9-
return sum(logpdf.(dist.v, x))
15+
return sum(map((d, x) -> logpdf(d, x), dist.v, x))
1016
end
1117
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
1218
# eachcol breaks Zygote, so we need an adjoint
13-
return mapvcat((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x))
19+
return mapvcat(dist.v, eachcol(x)) do dist, c
20+
sum(map(c) do x
21+
logpdf(dist, c)
22+
end)
23+
end
1424
end
1525
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
1626
# Any other more efficient implementation breaks Zygote
@@ -32,7 +42,9 @@ end
3242
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
3343
# Broadcasting here breaks Tracker for some reason
3444
# A Zygote adjoint is defined for mapvcat to use broadcasting
35-
return sum(logpdf.(dist.dists, x))
45+
return sum(map(dist.dists, x) do dist, x
46+
logpdf(dist, x)
47+
end)
3648
end
3749
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
3850
return mapvcat(x -> logpdf(dist, x), x)
@@ -69,7 +81,7 @@ function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Mat
6981
return mapvcat(x -> logpdf(dist, x), x)
7082
end
7183
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
72-
f(dist, x) = sum(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2))
84+
f(dist, x) = sum(mapvcat(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
7385
return pullback(f, dist, x)
7486
end
7587
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)

src/common.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
## Generic ##
22

33
_istracked(x) = false
4+
_istracked(x::TrackedArray) = false
45
_istracked(x::AbstractArray{<:TrackedReal}) = true
56
function mapvcat(f, args...)
6-
out = f.(args...,)
7+
out = map(f, args...)
78
if _istracked(out)
89
init = vcat(out[1])
910
return reshape(reduce(vcat, drop(out, 1); init = init), size(out))
@@ -12,7 +13,7 @@ function mapvcat(f, args...)
1213
end
1314
end
1415
@adjoint function mapvcat(f, args...)
15-
g(f, args...) = f.(args...)
16+
g(f, args...) = map(f, args...)
1617
return pullback(g, f, args...)
1718
end
1819

src/filldist.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,26 +48,20 @@ end
4848
function _flat_logpdf(dist, x)
4949
if toflatten(dist)
5050
f, args = flatten(dist)
51-
if any(Tracker.istracked, args)
52-
return sum(f.(args..., x))
53-
else
54-
return sum(logpdf.(dist, x))
55-
end
51+
return sum(f.(args..., x))
5652
else
57-
return sum(logpdf.(dist, x))
53+
return sum(mapvcat(x) do x
54+
logpdf(dist, x)
55+
end)
5856
end
5957
end
6058
function _flat_logpdf_mat(dist, x)
6159
if toflatten(dist)
6260
f, args = flatten(dist)
63-
if any(Tracker.istracked, args)
64-
return vec(sum(f.(args..., x), dims = 1))
65-
else
66-
return vec(sum(logpdf.(dist, x), dims = 1))
67-
end
61+
return vec(sum(f.(args..., x), dims = 1))
6862
else
6963
temp = mapvcat(x -> logpdf(dist, x), x)
70-
return vec(sum(reshape(temp, size(x)), dims = 1))
64+
return vec(sum(temp, dims = 1))
7165
end
7266
end
7367

src/flatten.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ const flattened_dists = [ Bernoulli,
4141
FDist,
4242
Frechet,
4343
Gamma,
44-
#GeneralizedExtremeValue,
44+
GeneralizedExtremeValue,
4545
GeneralizedPareto,
4646
Gumbel,
4747
#InverseGamma,

0 commit comments

Comments
 (0)