Skip to content

Commit 2f9d942

Browse files
authored
Merge pull request #37 from TuringLang/mt/perf_fixes
Minor performance and bug fixes (lessons learnt from TuringExamples)
2 parents 0f29efb + a0d96e0 commit 2f9d942

File tree

13 files changed

+370
-243
lines changed

13 files changed

+370
-243
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
name: ForwardDiff and Tracker tests
2+
3+
on:
4+
push:
5+
branches:
6+
- master
7+
pull_request:
8+
types: [opened, synchronize, reopened]
9+
10+
jobs:
11+
test:
12+
runs-on: ${{ matrix.os }}
13+
strategy:
14+
matrix:
15+
julia-version: [1.0.5, 1.2.0, 1.3]
16+
julia-arch: [x64, x86]
17+
os: [ubuntu-latest, macOS-latest]
18+
exclude:
19+
- os: macOS-latest
20+
julia-arch: x86
21+
22+
steps:
23+
- uses: actions/[email protected]
24+
- uses: julia-actions/setup-julia@latest
25+
with:
26+
version: ${{ matrix.julia-version }}
27+
- uses: julia-actions/julia-runtest@master
28+
env:
29+
STAGE: ForwardDiff_Tracker

.github/workflows/CI.yml renamed to .github/workflows/Others.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: CI
1+
name: Other tests
22

33
on:
44
push:
@@ -25,3 +25,5 @@ jobs:
2525
with:
2626
version: ${{ matrix.julia-version }}
2727
- uses: julia-actions/julia-runtest@master
28+
env:
29+
STAGE: Others

.github/workflows/Zygote.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
name: Zygote tests
2+
3+
on:
4+
push:
5+
branches:
6+
- master
7+
pull_request:
8+
types: [opened, synchronize, reopened]
9+
10+
jobs:
11+
test:
12+
runs-on: ${{ matrix.os }}
13+
strategy:
14+
matrix:
15+
julia-version: [1.0.5, 1.2.0, 1.3]
16+
julia-arch: [x64, x86]
17+
os: [ubuntu-latest, macOS-latest]
18+
exclude:
19+
- os: macOS-latest
20+
julia-arch: x86
21+
22+
steps:
23+
- uses: actions/[email protected]
24+
- uses: julia-actions/setup-julia@latest
25+
with:
26+
version: ${{ matrix.julia-version }}
27+
- uses: julia-actions/julia-runtest@master
28+
env:
29+
STAGE: Zygote

src/arraydist.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,25 @@
22

33
const VectorOfUnivariate = Distributions.Product
44

5-
function arraydist(dists::AbstractVector{<:Normal{T}}) where {T}
6-
means = mean.(dists)
7-
vars = var.(dists)
8-
return MvNormal(means, vars)
9-
end
10-
function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}})
11-
means = vcatmapreduce(mean, dists)
12-
vars = vcatmapreduce(var, dists)
13-
return MvNormal(means, vars)
14-
end
155
function arraydist(dists::AbstractVector{<:UnivariateDistribution})
166
return product_distribution(dists)
177
end
8+
function arraydist(dists::AbstractVector{<:Normal})
9+
m = mapvcat(mean, dists)
10+
v = mapvcat(var, dists)
11+
return MvNormal(m, v)
12+
end
13+
1814
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
19-
return sum(vcatmapreduce(logpdf, dist.v, x))
15+
return sum(map((d, x) -> logpdf(d, x), dist.v, x))
2016
end
2117
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
2218
# eachcol breaks Zygote, so we need an adjoint
23-
return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x))
19+
return mapvcat(dist.v, eachcol(x)) do dist, c
20+
sum(map(c) do x
21+
logpdf(dist, x)
22+
end)
23+
end
2424
end
2525
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
2626
# Any other more efficient implementation breaks Zygote
@@ -41,14 +41,16 @@ function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
4141
end
4242
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
4343
# Broadcasting here breaks Tracker for some reason
44-
# A Zygote adjoint is defined for vcatmapreduce to use broadcasting
45-
return sum(vcatmapreduce(logpdf, dist.dists, x))
44+
# A Zygote adjoint is defined for mapvcat to use broadcasting
45+
return sum(map(dist.dists, x) do dist, x
46+
logpdf(dist, x)
47+
end)
4648
end
4749
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
48-
return vcatmapreduce(x -> logpdf(dist, x), x)
50+
return mapvcat(x -> logpdf(dist, x), x)
4951
end
5052
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}})
51-
return vcatmapreduce(x -> logpdf(dist, x), x)
53+
return mapvcat(x -> logpdf(dist, x), x)
5254
end
5355
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
5456
return rand.(Ref(rng), dist.dists)
@@ -70,16 +72,16 @@ function arraydist(dists::AbstractVector{<:MultivariateDistribution})
7072
end
7173
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
7274
# eachcol breaks Zygote, so we define an adjoint
73-
return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x)))
75+
return sum(logpdf.(dist.dists, eachcol(x)))
7476
end
7577
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
76-
return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x))
78+
return mapvcat(x -> logpdf(dist, x), x)
7779
end
7880
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
79-
return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x))
81+
return mapvcat(x -> logpdf(dist, x), x)
8082
end
8183
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
82-
f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
84+
f(dist, x) = sum(mapvcat(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
8385
return pullback(f, dist, x)
8486
end
8587
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)

src/common.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
## Generic ##
22

3-
function vcatmapreduce(f, args...)
4-
init = vcat(f(first.(args)...,))
5-
zipped_args = zip(args...,)
6-
return mapreduce(vcat, drop(zipped_args, 1); init = init) do zarg
7-
f(zarg...,)
3+
_istracked(x) = false
4+
_istracked(x::TrackedArray) = false
5+
_istracked(x::AbstractArray{<:TrackedReal}) = true
6+
function mapvcat(f, args...)
7+
out = map(f, args...)
8+
if _istracked(out)
9+
init = vcat(out[1])
10+
return reshape(reduce(vcat, drop(out, 1); init = init), size(out))
11+
else
12+
return out
813
end
914
end
10-
@adjoint function vcatmapreduce(f, args...)
11-
g(f, args...) = f.(args...)
15+
@adjoint function mapvcat(f, args...)
16+
g(f, args...) = map(f, args...)
1217
return pullback(g, f, args...)
1318
end
1419

src/filldist.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,26 +48,20 @@ end
4848
function _flat_logpdf(dist, x)
4949
if toflatten(dist)
5050
f, args = flatten(dist)
51-
if any(Tracker.istracked, args)
52-
return sum(f.(args..., x))
53-
else
54-
return sum(logpdf.(dist, x))
55-
end
51+
return sum(f.(args..., x))
5652
else
57-
return sum(vcatmapreduce(x -> logpdf(dist, x), x))
53+
return sum(mapvcat(x) do x
54+
logpdf(dist, x)
55+
end)
5856
end
5957
end
6058
function _flat_logpdf_mat(dist, x)
6159
if toflatten(dist)
6260
f, args = flatten(dist)
63-
if any(Tracker.istracked, args)
64-
return vec(sum(f.(args..., x), dims = 1))
65-
else
66-
return vec(sum(logpdf.(dist, x), dims = 1))
67-
end
61+
return vec(sum(f.(args..., x), dims = 1))
6862
else
69-
temp = vcatmapreduce(x -> logpdf(dist, x), x)
70-
return vec(sum(reshape(temp, size(x)), dims = 1))
63+
temp = mapvcat(x -> logpdf(dist, x), x)
64+
return vec(sum(temp, dims = 1))
7165
end
7266
end
7367

src/flatten.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ const flattened_dists = [ Bernoulli,
4141
FDist,
4242
Frechet,
4343
Gamma,
44-
#GeneralizedExtremeValue,
44+
GeneralizedExtremeValue,
4545
GeneralizedPareto,
4646
Gumbel,
4747
#InverseGamma,
@@ -63,6 +63,7 @@ const flattened_dists = [ Bernoulli,
6363
TDist,
6464
TriangularDist,
6565
Triweight,
66+
TuringUniform,
6667
#Truncated,
6768
#VonMises,
6869
]

src/matrixvariate.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
## MatrixBeta
22

33
function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:TrackedMatrix{<:Real}})
4-
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
4+
return mapvcat(x -> logpdf(d, x), X)
55
end
66
@adjoint function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:Matrix{<:Real}})
77
f(d, X) = map(x -> logpdf(d, x), X)
@@ -112,10 +112,10 @@ function Distributions.logpdf(d::TuringWishart, X::AbstractMatrix{<:Real})
112112
return 0.5 * ((df - (p + 1)) * logdet(Xcf) - tr(d.chol \ X)) - d.c0
113113
end
114114
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
115-
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
115+
return mapvcat(x -> logpdf(d, x), X)
116116
end
117117
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:Matrix{<:Real}})
118-
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
118+
return mapvcat(x -> logpdf(d, x), X)
119119
end
120120

121121
#### Sampling
@@ -233,10 +233,10 @@ function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real}
233233
-0.5 * ((df + p + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) - d.c0
234234
end
235235
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
236-
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
236+
return mapvcat(x -> logpdf(d, x), X)
237237
end
238238
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:Matrix{<:Real}})
239-
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
239+
return mapvcat(x -> logpdf(d, x), X)
240240
end
241241

242242
#### Sampling

src/univariate.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ function TuringUniform(a::Real, b::Real)
1313
return TuringUniform{T}(T(a), T(b))
1414
end
1515
Distributions.logpdf(d::TuringUniform, x::Real) = uniformlogpdf(d.a, d.b, x)
16+
Base.minimum(d::TuringUniform) = d.a
17+
Base.maximum(d::TuringUniform) = d.b
1618

1719
Distributions.Uniform(a::TrackedReal, b::Real) = TuringUniform{TrackedReal}(a, b)
1820
Distributions.Uniform(a::Real, b::TrackedReal) = TuringUniform{TrackedReal}(a, b)
@@ -348,3 +350,21 @@ function Base.convert(
348350
DiscreteNonParametric{T,P,Ts,Ps}(support(d), probs(d), check_args=false)
349351
end
350352

353+
# Fix SubArray support
354+
function Distributions.DiscreteNonParametric{T,P,Ts,Ps}(
355+
vs::Ts,
356+
ps::Ps;
357+
check_args=true,
358+
) where {T<:Real, P<:Real, Ts<:AbstractVector{T}, Ps<:SubArray{P, 1}}
359+
cps = ps[:]
360+
return DiscreteNonParametric{T,P,Ts,typeof(cps)}(vs, cps; check_args = check_args)
361+
end
362+
363+
function Distributions.DiscreteNonParametric{T,P,Ts,Ps}(
364+
vs::Ts,
365+
ps::Ps;
366+
check_args=true,
367+
) where {T<:Real,P<:Real,Ts<:AbstractVector{T},Ps<:TrackedArray{P, 1, <:SubArray{P, 1}}}
368+
cps = ps[:]
369+
return DiscreteNonParametric{T,P,Ts,typeof(cps)}(vs, cps; check_args = check_args)
370+
end

0 commit comments

Comments
 (0)