Skip to content

Commit 3f6c538

Browse files
authored
Add missing method for varname_leaves; add tests (#141)
* Add missing method for `varname_leaves`; add tests * Add even more tests * Format
1 parent 5505e05 commit 3f6c538

File tree

4 files changed

+104
-2
lines changed

4 files changed

+104
-2
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.13.2
2+
3+
Implemented `varname_leaves` for `LinearAlgebra.Cholesky`.
4+
15
## 0.13.1
26

37
Moved the functions `varname_leaves` and `varname_and_value_leaves` to AbstractPPL.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
33
keywords = ["probablistic programming"]
44
license = "MIT"
55
desc = "Common interfaces for probabilistic programming"
6-
version = "0.13.1"
6+
version = "0.13.2"
77

88
[deps]
99
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/varname/leaves.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,25 @@ function varname_leaves(vn::VarName, val::NamedTuple)
4646
end
4747
return Iterators.flatten(iter)
4848
end
49+
function varname_leaves(vn::VarName, val::LinearAlgebra.Cholesky)
50+
return if val.uplo == 'L'
51+
optic = Accessors.PropertyLens{:L}()
52+
varname_leaves(VarName{getsym(vn)}(optic getoptic(vn)), val.L)
53+
else
54+
optic = Accessors.PropertyLens{:U}()
55+
varname_leaves(VarName{getsym(vn)}(optic getoptic(vn)), val.U)
56+
end
57+
end
58+
function varname_leaves(vn::VarName, x::LinearAlgebra.LowerTriangular)
59+
return Iterators.map(Iterators.filter(I -> I[1] >= I[2], CartesianIndices(x))) do I
60+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn))
61+
end
62+
end
63+
function varname_leaves(vn::VarName, x::LinearAlgebra.UpperTriangular)
64+
return Iterators.map(Iterators.filter(I -> I[1] <= I[2], CartesianIndices(x))) do I
65+
VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) getoptic(vn))
66+
end
67+
end
4968

5069
"""
5170
varname_and_value_leaves(vn::VarName, val)
@@ -220,7 +239,6 @@ function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple)
220239
end
221240
# Special types.
222241
function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.Cholesky)
223-
# TODO: Or do we use `PDMat` here?
224242
return if x.uplo == 'L'
225243
varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() vn, x.L)
226244
else

test/varname.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Accessors
22
using InvertedIndices
33
using OffsetArrays
4+
using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky
45

56
using AbstractPPL: , , , ,
67

@@ -342,4 +343,83 @@ end
342343
end
343344
end
344345
end
346+
347+
@testset "varname{_and_value}_leaves" begin
348+
@testset "single value: float, int" begin
349+
x = 1.0
350+
@test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)])
351+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) ==
352+
Set([(@varname(x), x)])
353+
x = 2
354+
@test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)])
355+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) ==
356+
Set([(@varname(x), x)])
357+
end
358+
359+
@testset "Vector" begin
360+
x = randn(2)
361+
@test Set(varname_leaves(@varname(x), x)) ==
362+
Set([@varname(x[1]), @varname(x[2])])
363+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) ==
364+
Set([(@varname(x[1]), x[1]), (@varname(x[2]), x[2])])
365+
x = [(; a=1), (; b=2)]
366+
@test Set(varname_leaves(@varname(x), x)) ==
367+
Set([@varname(x[1].a), @varname(x[2].b)])
368+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) ==
369+
Set([(@varname(x[1].a), x[1].a), (@varname(x[2].b), x[2].b)])
370+
end
371+
372+
@testset "Matrix" begin
373+
x = randn(2, 2)
374+
@test Set(varname_leaves(@varname(x), x)) == Set([
375+
@varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 1]), @varname(x[2, 2])
376+
])
377+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([
378+
(@varname(x[1, 1]), x[1, 1]),
379+
(@varname(x[1, 2]), x[1, 2]),
380+
(@varname(x[2, 1]), x[2, 1]),
381+
(@varname(x[2, 2]), x[2, 2]),
382+
])
383+
end
384+
385+
@testset "Lower/UpperTriangular" begin
386+
x = randn(2, 2)
387+
xl = LowerTriangular(x)
388+
@test Set(varname_leaves(@varname(x), xl)) ==
389+
Set([@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[2, 2])])
390+
@test Set(collect(varname_and_value_leaves(@varname(x), xl))) == Set([
391+
(@varname(x[1, 1]), x[1, 1]),
392+
(@varname(x[2, 1]), x[2, 1]),
393+
(@varname(x[2, 2]), x[2, 2]),
394+
])
395+
xu = UpperTriangular(x)
396+
@test Set(varname_leaves(@varname(x), xu)) ==
397+
Set([@varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 2])])
398+
@test Set(collect(varname_and_value_leaves(@varname(x), xu))) == Set([
399+
(@varname(x[1, 1]), x[1, 1]),
400+
(@varname(x[1, 2]), x[1, 2]),
401+
(@varname(x[2, 2]), x[2, 2]),
402+
])
403+
end
404+
405+
@testset "NamedTuple" begin
406+
x = (a=1.0, b=[2.0, 3.0])
407+
@test Set(varname_leaves(@varname(x), x)) ==
408+
Set([@varname(x.a), @varname(x.b[1]), @varname(x.b[2])])
409+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([
410+
(@varname(x.a), x.a), (@varname(x.b[1]), x.b[1]), (@varname(x.b[2]), x.b[2])
411+
])
412+
end
413+
414+
@testset "Cholesky" begin
415+
x = cholesky([1.0 0.5; 0.5 1.0])
416+
@test Set(varname_leaves(@varname(x), x)) ==
417+
Set([@varname(x.U[1, 1]), @varname(x.U[1, 2]), @varname(x.U[2, 2])])
418+
@test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([
419+
(@varname(x.U[1, 1]), x.U[1, 1]),
420+
(@varname(x.U[1, 2]), x.U[1, 2]),
421+
(@varname(x.U[2, 2]), x.U[2, 2]),
422+
])
423+
end
424+
end
345425
end

0 commit comments

Comments
 (0)