Skip to content

Commit 9c02c84

Browse files
committed
perf fixes
1 parent 0f29efb commit 9c02c84

File tree

6 files changed

+48
-33
lines changed

6 files changed

+48
-33
lines changed

src/arraydist.jl

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,15 @@
22

33
const VectorOfUnivariate = Distributions.Product
44

5-
function arraydist(dists::AbstractVector{<:Normal{T}}) where {T}
6-
means = mean.(dists)
7-
vars = var.(dists)
8-
return MvNormal(means, vars)
9-
end
10-
function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}})
11-
means = vcatmapreduce(mean, dists)
12-
vars = vcatmapreduce(var, dists)
13-
return MvNormal(means, vars)
14-
end
155
function arraydist(dists::AbstractVector{<:UnivariateDistribution})
166
return product_distribution(dists)
177
end
188
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
19-
return sum(vcatmapreduce(logpdf, dist.v, x))
9+
return sum(logpdf.(dist.v, x))
2010
end
2111
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
2212
# eachcol breaks Zygote, so we need an adjoint
23-
return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x))
13+
return mapvcat((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x))
2414
end
2515
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
2616
# Any other more efficient implementation breaks Zygote
@@ -41,14 +31,14 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
4131
end
4232
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
4333
# Broadcasting here breaks Tracker for some reason
44-
# A Zygote adjoint is defined for vcatmapreduce to use broadcasting
45-
return sum(vcatmapreduce(logpdf, dist.dists, x))
34+
# A Zygote adjoint is defined for mapvcat to use broadcasting
35+
return sum(logpdf.(dist.dists, x))
4636
end
4737
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
48-
return vcatmapreduce(x -> logpdf(dist, x), x)
38+
return mapvcat(x -> logpdf(dist, x), x)
4939
end
5040
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}})
51-
return vcatmapreduce(x -> logpdf(dist, x), x)
41+
return mapvcat(x -> logpdf(dist, x), x)
5242
end
5343
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
5444
return rand.(Ref(rng), dist.dists)
@@ -70,16 +60,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution})
7060
end
7161
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
7262
# eachcol breaks Zygote, so we define an adjoint
73-
return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x)))
63+
return sum(logpdf.(dist.dists, eachcol(x)))
7464
end
7565
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
76-
return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x))
66+
return mapvcat(x -> logpdf(dist, x), x)
7767
end
7868
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
79-
return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x))
69+
return mapvcat(x -> logpdf(dist, x), x)
8070
end
8171
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
82-
f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
72+
f(dist, x) = sum(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2))
8373
return pullback(f, dist, x)
8474
end
8575
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)

src/common.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
## Generic ##
22

3-
function vcatmapreduce(f, args...)
4-
init = vcat(f(first.(args)...,))
5-
zipped_args = zip(args...,)
6-
return mapreduce(vcat, drop(zipped_args, 1); init = init) do zarg
7-
f(zarg...,)
3+
_istracked(x) = false
4+
_istracked(x::AbstractArray{<:TrackedReal}) = true
5+
function mapvcat(f, args...)
6+
out = f.(args...,)
7+
if _istracked(out)
8+
init = vcat(out[1])
9+
return reshape(reduce(vcat, drop(out, 1); init = init), size(out))
10+
else
11+
return out
812
end
913
end
10-
@adjoint function vcatmapreduce(f, args...)
14+
@adjoint function mapvcat(f, args...)
1115
g(f, args...) = f.(args...)
1216
return pullback(g, f, args...)
1317
end

src/filldist.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function _flat_logpdf(dist, x)
5454
return sum(logpdf.(dist, x))
5555
end
5656
else
57-
return sum(vcatmapreduce(x -> logpdf(dist, x), x))
57+
return sum(logpdf.(dist, x))
5858
end
5959
end
6060
function _flat_logpdf_mat(dist, x)
@@ -66,7 +66,7 @@ function _flat_logpdf_mat(dist, x)
6666
return vec(sum(logpdf.(dist, x), dims = 1))
6767
end
6868
else
69-
temp = vcatmapreduce(x -> logpdf(dist, x), x)
69+
temp = mapvcat(x -> logpdf(dist, x), x)
7070
return vec(sum(reshape(temp, size(x)), dims = 1))
7171
end
7272
end

src/flatten.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ const flattened_dists = [ Bernoulli,
6363
TDist,
6464
TriangularDist,
6565
Triweight,
66+
TuringUniform,
6667
#Truncated,
6768
#VonMises,
6869
]

src/matrixvariate.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
## MatrixBeta
22

33
function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:TrackedMatrix{<:Real}})
4-
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
4+
return mapvcat(x -> logpdf(d, x), X)
55
end
66
@adjoint function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:Matrix{<:Real}})
77
f(d, X) = map(x -> logpdf(d, x), X)
@@ -112,10 +112,10 @@ function Distributions.logpdf(d::TuringWishart, X::AbstractMatrix{<:Real})
112112
return 0.5 * ((df - (p + 1)) * logdet(Xcf) - tr(d.chol \ X)) - d.c0
113113
end
114114
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
115-
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
115+
return mapvcat(x -> logpdf(d, x), X)
116116
end
117117
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:Matrix{<:Real}})
118-
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
118+
return mapvcat(x -> logpdf(d, x), X)
119119
end
120120

121121
#### Sampling
@@ -233,10 +233,10 @@ function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real}
233233
-0.5 * ((df + p + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) - d.c0
234234
end
235235
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
236-
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
236+
return mapvcat(x -> logpdf(d, x), X)
237237
end
238238
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:Matrix{<:Real}})
239-
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
239+
return mapvcat(x -> logpdf(d, x), X)
240240
end
241241

242242
#### Sampling

src/univariate.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ function TuringUniform(a::Real, b::Real)
1313
return TuringUniform{T}(T(a), T(b))
1414
end
1515
Distributions.logpdf(d::TuringUniform, x::Real) = uniformlogpdf(d.a, d.b, x)
16+
Base.minimum(d::TuringUniform) = d.a
17+
Base.maximum(d::TuringUniform) = d.b
1618

1719
Distributions.Uniform(a::TrackedReal, b::Real) = TuringUniform{TrackedReal}(a, b)
1820
Distributions.Uniform(a::Real, b::TrackedReal) = TuringUniform{TrackedReal}(a, b)
@@ -348,3 +350,21 @@ function Base.convert(
348350
DiscreteNonParametric{T,P,Ts,Ps}(support(d), probs(d), check_args=false)
349351
end
350352

353+
# Fix SubArray support
354+
function Distributions.DiscreteNonParametric{T,P,Ts,Ps}(
355+
vs::Ts,
356+
ps::Ps;
357+
check_args=true,
358+
) where {T<:Real, P<:Real, Ts<:AbstractVector{T}, Ps<:SubArray{P, 1}}
359+
cps = ps[:]
360+
return DiscreteNonParametric{T,P,Ts,typeof(cps)}(vs, cps; check_args = check_args)
361+
end
362+
363+
function Distributions.DiscreteNonParametric{T,P,Ts,Ps}(
364+
vs::Ts,
365+
ps::Ps;
366+
check_args=true,
367+
) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:TrackedArray{P, 1, <:SubArray{P, 1}}}
368+
cps = ps[:]
369+
return DiscreteNonParametric{T,P,Ts,typeof(cps)}(vs, cps; check_args = check_args)
370+
end

0 commit comments

Comments
 (0)