Skip to content

Commit 6711168

Browse files
committed
test Diagonal covariance
1 parent e958f72 commit 6711168

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

test/distributions.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,26 +176,34 @@ separator()
176176
# Vector case
177177
DistSpec(:MvNormal, (mean, cov_mat), norm_val_vec),
178178
DistSpec(:MvNormal, (mean, cov_vec), norm_val_vec),
179+
DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_vec),
179180
DistSpec(:MvNormal, (mean, cov_num), norm_val_vec),
180181
DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_vec),
181182
DistSpec(:MvNormal, (cov_mat,), norm_val_vec),
182183
DistSpec(:MvNormal, (cov_vec,), norm_val_vec),
184+
DistSpec(:MvNormal, (Diagonal(cov_vec),), norm_val_vec),
183185
DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_vec),
184186
DistSpec(:MvLogNormal, (mean, cov_mat), norm_val_vec),
185187
DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_vec),
188+
DistSpec(:MvLogNormal, (mean, Diagonal(cov_vec)), norm_val_vec),
186189
DistSpec(:MvLogNormal, (mean, cov_num), norm_val_vec),
187190
DistSpec(:MvLogNormal, (cov_mat,), norm_val_vec),
188191
DistSpec(:MvLogNormal, (cov_vec,), norm_val_vec),
192+
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_vec),
189193
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_vec),
190194
# Matrix case
191195
DistSpec(:MvNormal, (mean, cov_vec), norm_val_mat),
196+
DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_mat),
192197
DistSpec(:MvNormal, (mean, cov_num), norm_val_mat),
193198
DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_mat),
194199
DistSpec(:MvNormal, (cov_vec,), norm_val_mat),
200+
DistSpec(:MvNormal, (Diagonal(cov_vec),), norm_val_mat),
195201
DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_mat),
196202
DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_mat),
203+
DistSpec(:MvLogNormal, (mean, Diagonal(cov_vec)), norm_val_mat),
197204
DistSpec(:MvLogNormal, (mean, cov_num), norm_val_mat),
198205
DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat),
206+
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat),
199207
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat),
200208
]
201209

test/test_utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ struct DistSpec{Tθ<:Tuple, Tx}
1010
end
1111

1212
vectorize(v::Number) = [v]
13+
vectorize(v::Diagonal) = v.diag
1314
vectorize(v) = vec(v)
1415
pack(vals...) = reduce(vcat, vectorize.(vals))
1516
@generated function unpack(x, vals...)
@@ -22,6 +23,9 @@ pack(vals...) = reduce(vcat, vectorize.(vals))
2223
elseif T <: Vector
2324
push!(unpacked, :(x[$ind:$ind+length(vals[$i])-1]))
2425
ind = :($ind + length(vals[$i]))
26+
elseif T <: Diagonal
27+
push!(unpacked, :(Diagonal(x[$ind:$ind+size(vals[$i],1)-1])))
28+
ind = :($ind + size(vals[$i],1))
2529
elseif T <: Matrix
2630
push!(unpacked, :(reshape(x[$ind:($ind+length(vals[$i])-1)], size(vals[$i]))))
2731
ind = :($ind + length(vals[$i]))

0 commit comments

Comments
 (0)