Skip to content

Commit 91fb311

Browse files
authored
Merge pull request #26 from TuringLang/mt/missing_MvNormal_constr
Missing MvNormal and MvLogNormal constructors and Diagonal covariance tests
2 parents 4bce688 + 6711168 commit 91fb311

File tree

3 files changed

+30
-0
lines changed

3 files changed

+30
-0
lines changed

src/multivariate.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ function MvNormal(
103103
)
104104
return TuringMvNormal(m, D)
105105
end
106+
function MvNormal(
107+
m::TrackedVector{<:Real},
108+
D::Diagonal{T, <:AbstractVector{T}} where {T<:Real},
109+
)
110+
return TuringMvNormal(m, D)
111+
end
106112

107113
# dense mean, diagonal covariance
108114
MvNormal(m::TrackedVector{<:Real}, σ::TrackedVector{<:Real}) = TuringMvNormal(m, σ)
@@ -211,6 +217,18 @@ function MvLogNormal(
211217
)
212218
return TuringMvLogNormal(TuringMvNormal(m, D))
213219
end
220+
function MvLogNormal(
221+
m::TrackedVector{<:Real},
222+
D::Diagonal{T, <:AbstractVector{T}} where {T<:Real},
223+
)
224+
return TuringMvLogNormal(TuringMvNormal(m, D))
225+
end
226+
function MvLogNormal(
227+
m::AbstractVector{<:Real},
228+
D::Diagonal{T, <:AbstractVector{T}} where {T<:Real},
229+
)
230+
return MvLogNormal(MvNormal(m, D))
231+
end
214232

215233
# dense mean, diagonal covariance
216234
MvLogNormal(m::TrackedVector{<:Real}, σ::TrackedVector{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, σ))

test/distributions.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,26 +176,34 @@ separator()
176176
# Vector case
177177
DistSpec(:MvNormal, (mean, cov_mat), norm_val_vec),
178178
DistSpec(:MvNormal, (mean, cov_vec), norm_val_vec),
179+
DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_vec),
179180
DistSpec(:MvNormal, (mean, cov_num), norm_val_vec),
180181
DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_vec),
181182
DistSpec(:MvNormal, (cov_mat,), norm_val_vec),
182183
DistSpec(:MvNormal, (cov_vec,), norm_val_vec),
184+
DistSpec(:MvNormal, (Diagonal(cov_vec),), norm_val_vec),
183185
DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_vec),
184186
DistSpec(:MvLogNormal, (mean, cov_mat), norm_val_vec),
185187
DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_vec),
188+
DistSpec(:MvLogNormal, (mean, Diagonal(cov_vec)), norm_val_vec),
186189
DistSpec(:MvLogNormal, (mean, cov_num), norm_val_vec),
187190
DistSpec(:MvLogNormal, (cov_mat,), norm_val_vec),
188191
DistSpec(:MvLogNormal, (cov_vec,), norm_val_vec),
192+
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_vec),
189193
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_vec),
190194
# Matrix case
191195
DistSpec(:MvNormal, (mean, cov_vec), norm_val_mat),
196+
DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_mat),
192197
DistSpec(:MvNormal, (mean, cov_num), norm_val_mat),
193198
DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_mat),
194199
DistSpec(:MvNormal, (cov_vec,), norm_val_mat),
200+
DistSpec(:MvNormal, (Diagonal(cov_vec),), norm_val_mat),
195201
DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_mat),
196202
DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_mat),
203+
DistSpec(:MvLogNormal, (mean, Diagonal(cov_vec)), norm_val_mat),
197204
DistSpec(:MvLogNormal, (mean, cov_num), norm_val_mat),
198205
DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat),
206+
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat),
199207
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat),
200208
]
201209

test/test_utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ struct DistSpec{Tθ<:Tuple, Tx}
1010
end
1111

1212
vectorize(v::Number) = [v]
13+
vectorize(v::Diagonal) = v.diag
1314
vectorize(v) = vec(v)
1415
pack(vals...) = reduce(vcat, vectorize.(vals))
1516
@generated function unpack(x, vals...)
@@ -22,6 +23,9 @@ pack(vals...) = reduce(vcat, vectorize.(vals))
2223
elseif T <: Vector
2324
push!(unpacked, :(x[$ind:$ind+length(vals[$i])-1]))
2425
ind = :($ind + length(vals[$i]))
26+
elseif T <: Diagonal
27+
push!(unpacked, :(Diagonal(x[$ind:$ind+size(vals[$i],1)-1])))
28+
ind = :($ind + size(vals[$i],1))
2529
elseif T <: Matrix
2630
push!(unpacked, :(reshape(x[$ind:($ind+length(vals[$i])-1)], size(vals[$i]))))
2731
ind = :($ind + length(vals[$i]))

0 commit comments

Comments
 (0)