Skip to content

Commit 45ccbab

Browse files
authored
Merge pull request #32 from TuringLang/mt/more_tests
Many more tests and test fixes
2 parents 6b83376 + 9ea5cf1 commit 45ccbab

File tree

14 files changed

+751
-265
lines changed

14 files changed

+751
-265
lines changed

.github/workflows/CI.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
branches:
6+
- master
7+
pull_request:
8+
types: [opened, synchronize, reopened]
9+
10+
jobs:
11+
test:
12+
runs-on: ${{ matrix.os }}
13+
strategy:
14+
matrix:
15+
julia-version: [1.0.5, 1.2.0, 1.3]
16+
julia-arch: [x64, x86]
17+
os: [ubuntu-latest, macOS-latest]
18+
exclude:
19+
- os: macOS-latest
20+
julia-arch: x86
21+
22+
steps:
23+
- uses: actions/[email protected]
24+
- uses: julia-actions/setup-julia@latest
25+
with:
26+
version: ${{ matrix.julia-version }}
27+
- uses: julia-actions/julia-runtest@master

.travis.yml

Lines changed: 0 additions & 23 deletions
This file was deleted.

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.3.2"
44

55
[deps]
66
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
7+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
78
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
89
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
910
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -20,6 +21,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2021

2122
[compat]
2223
Combinatorics = "0.7"
24+
Compat = "3.6"
2325
DiffRules = "0.1, 1.0"
2426
Distributions = "0.22"
2527
FillArrays = "0.8"

src/DistributionsAD.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ using PDMats,
88
Random,
99
Combinatorics,
1010
SpecialFunctions,
11-
StatsFuns
11+
StatsFuns,
12+
Compat
1213

1314
using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
1415
TrackedVecOrMat, track, @grad, data

src/arraydist.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real
4444
# A Zygote adjoint is defined for vcatmapreduce to use broadcasting
4545
return sum(vcatmapreduce(logpdf, dist.dists, x))
4646
end
47+
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
48+
return vcatmapreduce(x -> logpdf(dist, x), x)
49+
end
50+
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}})
51+
return vcatmapreduce(x -> logpdf(dist, x), x)
52+
end
4753
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
4854
return rand.(Ref(rng), dist.dists)
4955
end
@@ -66,6 +72,12 @@ function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Re
6672
# eachcol breaks Zygote, so we define an adjoint
6773
return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x)))
6874
end
75+
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
76+
return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x))
77+
end
78+
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
79+
return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x))
80+
end
6981
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
7082
f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
7183
return pullback(f, dist, x)

src/common.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ if VERSION < v"1.1"
44
eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2))
55
end
66

7-
Base.one(::Irrational) = true
8-
97
function vcatmapreduce(f, args...)
108
init = vcat(f(first.(args)...,))
119
zipped_args = zip(args...,)
@@ -14,7 +12,7 @@ function vcatmapreduce(f, args...)
1412
end
1513
end
1614
@adjoint function vcatmapreduce(f, args...)
17-
g(f, args...) = f.(args...,)
15+
g(f, args...) = f.(args...)
1816
return pullback(g, f, args...)
1917
end
2018

src/filldist.jl

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Univariate
22

3+
Tracker.dual(x::Int, p) = x
4+
35
const FillVectorOfUnivariate{
46
S <: ValueSupport,
57
T <: UnivariateDistribution{S},
@@ -46,15 +48,23 @@ end
4648
function _flat_logpdf(dist, x)
4749
if toflatten(dist)
4850
f, args = flatten(dist)
49-
return sum(f.(args..., x))
51+
if any(Tracker.istracked, args)
52+
return sum(f.(args..., x))
53+
else
54+
return sum(logpdf.(dist, x))
55+
end
5056
else
5157
return sum(vcatmapreduce(x -> logpdf(dist, x), x))
5258
end
5359
end
5460
function _flat_logpdf_mat(dist, x)
5561
if toflatten(dist)
5662
f, args = flatten(dist)
57-
return vec(sum(f.(args..., x), dims = 1))
63+
if any(Tracker.istracked, args)
64+
return vec(sum(f.(args..., x), dims = 1))
65+
else
66+
return vec(sum(logpdf.(dist, x), dims = 1))
67+
end
5868
else
5969
temp = vcatmapreduce(x -> logpdf(dist, x), x)
6070
return vec(sum(reshape(temp, size(x)), dims = 1))
@@ -74,7 +84,7 @@ function Distributions.logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:
7484
return _flat_logpdf(dist.dists.value, x)
7585
end
7686
function Distributions.rand(rng::Random.AbstractRNG, dist::FillMatrixOfUnivariate)
77-
return rand(rng, dist.dists.value, length.(dist.dists.axes))
87+
return rand(rng, dist.dists.value, length.(dist.dists.axes)...,)
7888
end
7989

8090
# Multivariate
@@ -94,18 +104,18 @@ function Distributions.logpdf(
94104
)
95105
return _logpdf(dist, x)
96106
end
97-
@adjoint function Distributions.logpdf(
107+
function _logpdf(
98108
dist::FillVectorOfMultivariate,
99109
x::AbstractMatrix{<:Real},
100110
)
101-
return pullback(_logpdf, dist, x)
111+
return sum(logpdf(dist.dists.value, x))
102112
end
103-
function _logpdf(
113+
@adjoint function Distributions.logpdf(
104114
dist::FillVectorOfMultivariate,
105115
x::AbstractMatrix{<:Real},
106116
)
107-
return sum(logpdf(dist.dists.value, x))
117+
return pullback(_logpdf, dist, x)
108118
end
109119
function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate)
110-
return rand(rng, dist.dists.value, length.(dist.dists.axes))
120+
return rand(rng, dist.dists.value, length.(dist.dists.axes)...,)
111121
end

src/flatten.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ const flattened_dists = [ Bernoulli,
2626
NegativeBinomial,
2727
Poisson,
2828
Skellam,
29-
PoissonBinomial,
3029
Arcsine,
3130
Beta,
3231
BetaPrime,
@@ -42,10 +41,10 @@ const flattened_dists = [ Bernoulli,
4241
FDist,
4342
Frechet,
4443
Gamma,
45-
GeneralizedExtremeValue,
44+
#GeneralizedExtremeValue,
4645
GeneralizedPareto,
4746
Gumbel,
48-
InverseGamma,
47+
#InverseGamma,
4948
InverseGaussian,
5049
Kolmogorov,
5150
Laplace,
@@ -55,17 +54,17 @@ const flattened_dists = [ Bernoulli,
5554
LogitNormal,
5655
LogNormal,
5756
Normal,
58-
NormalCanon,
59-
NormalInverseGaussian,
57+
#NormalCanon,
58+
#NormalInverseGaussian,
6059
Pareto,
6160
PGeneralizedGaussian,
6261
Rayleigh,
6362
SymTriangularDist,
6463
TDist,
6564
TriangularDist,
6665
Triweight,
67-
Categorical,
68-
Truncated,
66+
#Truncated,
67+
#VonMises,
6968
]
7069
for T in flattened_dists
7170
@eval toflatten(::$T) = true

src/matrixvariate.jl

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
## MatrixBeta
2+
3+
function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:TrackedMatrix{<:Real}})
4+
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
5+
end
6+
@adjoint function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:Matrix{<:Real}})
7+
f(d, X) = map(x -> logpdf(d, x), X)
8+
return pullback(f, d, X)
9+
end
10+
111
# Adapted from Distributions.jl
212

313
## Wishart
@@ -12,6 +22,13 @@ end
1222

1323
#### Constructors
1424

25+
function TuringWishart(d::Wishart)
26+
d = TuringWishart(d.df, getchol(d.S), d.c0)
27+
end
28+
getchol(p::PDMat) = p.chol
29+
getchol(p::PDiagMat) = Diagonal(map(sqrt, p.diag))
30+
getchol(p::ScalMat) = Diagonal(fill(sqrt(p.value), p.dim))
31+
1532
function TuringWishart(df::T, S::AbstractMatrix) where {T <: Real}
1633
p = size(S, 1)
1734
df > p - 1 || error("dpf should be greater than dim - 1.")
@@ -66,7 +83,7 @@ end
6683
function Distributions.entropy(d::TuringWishart)
6784
p = Distributions.dim(d)
6885
df = d.df
69-
d.c0 - 0.5 * (df - p - 1) * meanlogdet(d) + 0.5 * df * p
86+
d.c0 - 0.5 * (df - p - 1) * Distributions.meanlogdet(d) + 0.5 * df * p
7087
end
7188

7289
# Gupta/Nagar (1999) Theorem 3.3.15.i
@@ -82,12 +99,24 @@ end
8299

83100
#### Evaluation
84101

102+
function Distributions.logpdf(d::Wishart, X::TrackedMatrix)
103+
return logpdf(TuringWishart(d), X)
104+
end
105+
function Distributions.logpdf(d::Wishart, X::AbstractArray{<:TrackedMatrix})
106+
return logpdf(TuringWishart(d), X)
107+
end
85108
function Distributions.logpdf(d::TuringWishart, X::AbstractMatrix{<:Real})
86109
df = d.df
87110
p = Distributions.dim(d)
88111
Xcf = cholesky(X)
89112
return 0.5 * ((df - (p + 1)) * logdet(Xcf) - tr(d.chol \ X)) - d.c0
90113
end
114+
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
115+
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
116+
end
117+
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:Matrix{<:Real}})
118+
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
119+
end
91120

92121
#### Sampling
93122
function Distributions._rand!(rng::AbstractRNG, d::TuringWishart, A::AbstractMatrix)
@@ -128,6 +157,13 @@ end
128157

129158
#### Constructors
130159

160+
function TuringInverseWishart(d::InverseWishart)
161+
d = TuringInverseWishart(d.df, getmatrix(d.Ψ), d.c0)
162+
end
163+
getmatrix(p::PDMat) = p.mat
164+
getmatrix(p::PDiagMat) = Diagonal(p.diag)
165+
getmatrix(p::ScalMat) = Diagonal(fill(p.value, p.dim))
166+
131167
function TuringInverseWishart(df::T, Ψ::AbstractMatrix) where T<:Real
132168
p = size(Ψ, 1)
133169
df > p - 1 || error("df should be greater than dim - 1.")
@@ -182,6 +218,12 @@ end
182218

183219
#### Evaluation
184220

221+
function Distributions.logpdf(d::InverseWishart, X::TrackedMatrix)
222+
return logpdf(TuringInverseWishart(d), X)
223+
end
224+
function Distributions.logpdf(d::InverseWishart, X::AbstractArray{<:TrackedMatrix})
225+
return logpdf(TuringInverseWishart(d), X)
226+
end
185227
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real})
186228
p = Distributions.dim(d)
187229
df = d.df
@@ -190,7 +232,12 @@ function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real}
190232
Ψ = d.S
191233
-0.5 * ((df + p + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) - d.c0
192234
end
193-
235+
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
236+
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
237+
end
238+
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:Matrix{<:Real}})
239+
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
240+
end
194241

195242
#### Sampling
196243

src/multivariate.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,6 @@ struct TuringDiagMvNormal{Tm<:AbstractVector, Tσ<:AbstractVector} <: Continuous
116116
σ::Tσ
117117
end
118118

119-
Distributions.params(d::TuringDiagMvNormal) = (d.m, d.σ)
120-
Distributions.dim(d::TuringDiagMvNormal) = length(d.m)
121119
Base.length(d::TuringDiagMvNormal) = length(d.m)
122120
Base.size(d::TuringDiagMvNormal) = (length(d),)
123121
Distributions.rand(d::TuringDiagMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...)
@@ -164,6 +162,20 @@ function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix)
164162
return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2
165163
end
166164

165+
for T in (:TrackedVector, :TrackedMatrix)
166+
@eval begin
167+
function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T)
168+
logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x)
169+
end
170+
function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDiagMat}, x::$T)
171+
logpdf(TuringDiagMvNormal(d.μ, d.Σ.diag), x)
172+
end
173+
function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDMat}, x::$T)
174+
logpdf(TuringDenseMvNormal(d.μ, d.Σ.chol), x)
175+
end
176+
end
177+
end
178+
167179
import StatsBase: entropy
168180
function entropy(d::TuringDiagMvNormal)
169181
T = eltype(d.σ)

0 commit comments

Comments
 (0)