diff --git a/HISTORY.md b/HISTORY.md index ddb0be6..c7b61af 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,7 @@ +## 0.13.2 + +Implemented `varname_leaves` for `LinearAlgebra.Cholesky`. + ## 0.13.1 Moved the functions `varname_leaves` and `varname_and_value_leaves` to AbstractPPL. diff --git a/Project.toml b/Project.toml index c687d8d..c1940ca 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.13.1" +version = "0.13.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/varname/leaves.jl b/src/varname/leaves.jl index e4025e8..e0c8851 100644 --- a/src/varname/leaves.jl +++ b/src/varname/leaves.jl @@ -46,6 +46,25 @@ function varname_leaves(vn::VarName, val::NamedTuple) end return Iterators.flatten(iter) end +function varname_leaves(vn::VarName, val::LinearAlgebra.Cholesky) + return if val.uplo == 'L' + optic = Accessors.PropertyLens{:L}() + varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), val.L) + else + optic = Accessors.PropertyLens{:U}() + varname_leaves(VarName{getsym(vn)}(optic ∘ getoptic(vn)), val.U) + end +end +function varname_leaves(vn::VarName, x::LinearAlgebra.LowerTriangular) + return Iterators.map(Iterators.filter(I -> I[1] >= I[2], CartesianIndices(x))) do I + VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) + end +end +function varname_leaves(vn::VarName, x::LinearAlgebra.UpperTriangular) + return Iterators.map(Iterators.filter(I -> I[1] <= I[2], CartesianIndices(x))) do I + VarName{getsym(vn)}(Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) + end +end """ varname_and_value_leaves(vn::VarName, val) @@ -220,7 +239,6 @@ function varname_and_value_leaves_inner(vn::VarName, val::NamedTuple) end # Special types. function varname_and_value_leaves_inner(vn::VarName, x::LinearAlgebra.Cholesky) - # TODO: Or do we use `PDMat` here? return if x.uplo == 'L' varname_and_value_leaves_inner(Accessors.PropertyLens{:L}() ∘ vn, x.L) else diff --git a/test/varname.jl b/test/varname.jl index 2fbff9b..c4895e7 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -1,6 +1,7 @@ using Accessors using InvertedIndices using OffsetArrays +using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky using AbstractPPL: ⊑, ⊒, ⋢, ⋣, ≍ @@ -342,4 +343,83 @@ end end end end + + @testset "varname{_and_value}_leaves" begin + @testset "single value: float, int" begin + x = 1.0 + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x), x)]) + x = 2 + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x), x)]) + end + + @testset "Vector" begin + x = randn(2) + @test Set(varname_leaves(@varname(x), x)) == + Set([@varname(x[1]), @varname(x[2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x[1]), x[1]), (@varname(x[2]), x[2])]) + x = [(; a=1), (; b=2)] + @test Set(varname_leaves(@varname(x), x)) == + Set([@varname(x[1].a), @varname(x[2].b)]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == + Set([(@varname(x[1].a), x[1].a), (@varname(x[2].b), x[2].b)]) + end + + @testset "Matrix" begin + x = randn(2, 2) + @test Set(varname_leaves(@varname(x), x)) == Set([ + @varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 1]), @varname(x[2, 2]) + ]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ + (@varname(x[1, 1]), x[1, 1]), + (@varname(x[1, 2]), x[1, 2]), + (@varname(x[2, 1]), x[2, 1]), + (@varname(x[2, 2]), x[2, 2]), + ]) + end + + @testset "Lower/UpperTriangular" begin + x = randn(2, 2) + xl = LowerTriangular(x) + @test Set(varname_leaves(@varname(x), xl)) == + Set([@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[2, 2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), xl))) == Set([ + (@varname(x[1, 1]), x[1, 1]), + (@varname(x[2, 1]), x[2, 1]), + (@varname(x[2, 2]), x[2, 2]), + ]) + xu = UpperTriangular(x) + @test Set(varname_leaves(@varname(x), xu)) == + Set([@varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), xu))) == Set([ + (@varname(x[1, 1]), x[1, 1]), + (@varname(x[1, 2]), x[1, 2]), + (@varname(x[2, 2]), x[2, 2]), + ]) + end + + @testset "NamedTuple" begin + x = (a=1.0, b=[2.0, 3.0]) + @test Set(varname_leaves(@varname(x), x)) == + Set([@varname(x.a), @varname(x.b[1]), @varname(x.b[2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ + (@varname(x.a), x.a), (@varname(x.b[1]), x.b[1]), (@varname(x.b[2]), x.b[2]) + ]) + end + + @testset "Cholesky" begin + x = cholesky([1.0 0.5; 0.5 1.0]) + @test Set(varname_leaves(@varname(x), x)) == + Set([@varname(x.U[1, 1]), @varname(x.U[1, 2]), @varname(x.U[2, 2])]) + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ + (@varname(x.U[1, 1]), x.U[1, 1]), + (@varname(x.U[1, 2]), x.U[1, 2]), + (@varname(x.U[2, 2]), x.U[2, 2]), + ]) + end + end end