Skip to content

Commit f820426

Browse files
committed
Add missing method for varname_leaves; add tests
1 parent 5505e05 commit f820426

File tree

4 files changed

+97
-2
lines changed

4 files changed

+97
-2
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.13.2
2+
3+
Implemented `varname_leaves` for `LinearAlgebra.Cholesky`.
4+
15
## 0.13.1
26

37
Moved the functions `varname_leaves` and `varname_and_value_leaves` to AbstractPPL.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probablistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.13.1"
6+
version = "0.13.2"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/varname/leaves.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,25 @@ function varname_leaves(vn::VarName, val::NamedTuple)
4646
end
4747
return Iterators.flatten(iter)
4848
end
49+
function varname_leaves(vn::VarName, val::LinearAlgebra.Cholesky)
50+
return if val.uplo == 'L'
51+
optic = Accessors.PropertyLens{:L}()
52+
varname_leaves(VarName{getsym(vn)}(optic getoptic(vn)), val.L)
53+
else
54+
optic = Accessors.PropertyLens{:U}()
55+
varname_leaves(VarName{getsym(vn)}(optic getoptic(vn)), val.U)
56+
end
57+
end
58+
function varname_leaves(vn::VarName, x::LinearAlgebra.LowerTriangular)
59+
return Iterators.map(Iterators.filter(I -> I[1] >= I[2], CartesianIndices(x))) do I
60+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn))
61+
end
62+
end
63+
function varname_leaves(vn::VarName, x::LinearAlgebra.UpperTriangular)
64+
return Iterators.map(Iterators.filter(I -> I[1] <= I[2], CartesianIndices(x))) do I
65+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn))
66+
end
67+
end
4968

5069
"""
5170
varname_and_value_leaves(vn::VarName, val)
@@ -220,7 +239,6 @@ function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple)
220239
end
221240
# Special types.
222241
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.Cholesky)
223-
# TODO: Or do we use `PDMat` here?
224242
return if x.uplo == 'L'
225243
varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() vn, x.L)
226244
else

test/varname.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Accessors
22
using InvertedIndices
33
using OffsetArrays
4+
using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky
45

56
using AbstractPPL: , , , ,
67

@@ -342,4 +343,76 @@ end
342343
end
343344
end
344345
end
346+
347+
@testset "varname{_and_value}_leaves" begin
348+
@testset "single value: float, int" begin
349+
x = 1.0
350+
@test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)])
351+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) ==
352+
Set([(@varname(x), x)])
353+
x = 2
354+
@test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)])
355+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) ==
356+
Set([(@varname(x), x)])
357+
end
358+
359+
@testset "Vector" begin
360+
x = randn(2)
361+
@test Set(varname_leaves(@varname(x), x)) ==
362+
Set([@varname(x[1]), @varname(x[2])])
363+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) ==
364+
Set([(@varname(x[1]), x[1]), (@varname(x[2]), x[2])])
365+
end
366+
367+
@testset "Matrix" begin
368+
x = randn(2, 2)
369+
@test Set(varname_leaves(@varname(x), x)) == Set([
370+
@varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 1]), @varname(x[2, 2])
371+
])
372+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([
373+
(@varname(x[1, 1]), x[1, 1]),
374+
(@varname(x[1, 2]), x[1, 2]),
375+
(@varname(x[2, 1]), x[2, 1]),
376+
(@varname(x[2, 2]), x[2, 2]),
377+
])
378+
end
379+
380+
@testset "Lower/UpperTriangular" begin
381+
x = randn(2, 2)
382+
xl = LowerTriangular(x)
383+
@test Set(varname_leaves(@varname(x), xl)) ==
384+
Set([@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[2, 2])])
385+
@test Set(collect(varname_and_value_leaves(@varname(x), xl))) == Set([
386+
(@varname(x[1, 1]), x[1, 1]),
387+
(@varname(x[2, 1]), x[2, 1]),
388+
(@varname(x[2, 2]), x[2, 2]),
389+
])
390+
xu = UpperTriangular(x)
391+
@test Set(varname_leaves(@varname(x), xu)) ==
392+
Set([@varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 2])])
393+
@test Set(collect(varname_and_value_leaves(@varname(x), xu))) == Set([
394+
(@varname(x[1, 1]), x[1, 1]),
395+
(@varname(x[1, 2]), x[1, 2]),
396+
(@varname(x[2, 2]), x[2, 2]),
397+
])
398+
end
399+
400+
@testset "NamedTuple" begin
401+
x = (a=1.0, b=2.0)
402+
@test Set(varname_leaves(@varname(x), x)) == Set([@varname(x.a), @varname(x.b)])
403+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) ==
404+
Set([(@varname(x.a), x.a), (@varname(x.b), x.b)])
405+
end
406+
407+
@testset "Cholesky" begin
408+
x = cholesky([1.0 0.5; 0.5 1.0])
409+
@test Set(varname_leaves(@varname(x), x)) ==
410+
Set([@varname(x.U[1, 1]), @varname(x.U[1, 2]), @varname(x.U[2, 2])])
411+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([
412+
(@varname(x.U[1, 1]), x.U[1, 1]),
413+
(@varname(x.U[1, 2]), x.U[1, 2]),
414+
(@varname(x.U[2, 2]), x.U[2, 2]),
415+
])
416+
end
417+
end
345418
end

0 commit comments

Comments
 (0)