|
1 | 1 | using Accessors
|
2 | 2 | using InvertedIndices
|
3 | 3 | using OffsetArrays
|
| 4 | +using LinearAlgebra: LowerTriangular, UpperTriangular, cholesky |
4 | 5 |
|
5 | 6 | using AbstractPPL: ⊑, ⊒, ⋢, ⋣, ≍
|
6 | 7 |
|
|
342 | 343 | end
|
343 | 344 | end
|
344 | 345 | end
|
| 346 | + |
| 347 | + @testset "varname{_and_value}_leaves" begin |
| 348 | + @testset "single value: float, int" begin |
| 349 | + x = 1.0 |
| 350 | + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) |
| 351 | + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == |
| 352 | + Set([(@varname(x), x)]) |
| 353 | + x = 2 |
| 354 | + @test Set(varname_leaves(@varname(x), x)) == Set([@varname(x)]) |
| 355 | + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == |
| 356 | + Set([(@varname(x), x)]) |
| 357 | + end |
| 358 | + |
| 359 | + @testset "Vector" begin |
| 360 | + x = randn(2) |
| 361 | + @test Set(varname_leaves(@varname(x), x)) == |
| 362 | + Set([@varname(x[1]), @varname(x[2])]) |
| 363 | + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == |
| 364 | + Set([(@varname(x[1]), x[1]), (@varname(x[2]), x[2])]) |
| 365 | + x = [(; a=1), (; b=2)] |
| 366 | + @test Set(varname_leaves(@varname(x), x)) == |
| 367 | + Set([@varname(x[1].a), @varname(x[2].b)]) |
| 368 | + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == |
| 369 | + Set([(@varname(x[1].a), x[1].a), (@varname(x[2].b), x[2].b)]) |
| 370 | + end |
| 371 | + |
| 372 | + @testset "Matrix" begin |
| 373 | + x = randn(2, 2) |
| 374 | + @test Set(varname_leaves(@varname(x), x)) == Set([ |
| 375 | + @varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 1]), @varname(x[2, 2]) |
| 376 | + ]) |
| 377 | + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ |
| 378 | + (@varname(x[1, 1]), x[1, 1]), |
| 379 | + (@varname(x[1, 2]), x[1, 2]), |
| 380 | + (@varname(x[2, 1]), x[2, 1]), |
| 381 | + (@varname(x[2, 2]), x[2, 2]), |
| 382 | + ]) |
| 383 | + end |
| 384 | + |
| 385 | + @testset "Lower/UpperTriangular" begin |
| 386 | + x = randn(2, 2) |
| 387 | + xl = LowerTriangular(x) |
| 388 | + @test Set(varname_leaves(@varname(x), xl)) == |
| 389 | + Set([@varname(x[1, 1]), @varname(x[2, 1]), @varname(x[2, 2])]) |
| 390 | + @test Set(collect(varname_and_value_leaves(@varname(x), xl))) == Set([ |
| 391 | + (@varname(x[1, 1]), x[1, 1]), |
| 392 | + (@varname(x[2, 1]), x[2, 1]), |
| 393 | + (@varname(x[2, 2]), x[2, 2]), |
| 394 | + ]) |
| 395 | + xu = UpperTriangular(x) |
| 396 | + @test Set(varname_leaves(@varname(x), xu)) == |
| 397 | + Set([@varname(x[1, 1]), @varname(x[1, 2]), @varname(x[2, 2])]) |
| 398 | + @test Set(collect(varname_and_value_leaves(@varname(x), xu))) == Set([ |
| 399 | + (@varname(x[1, 1]), x[1, 1]), |
| 400 | + (@varname(x[1, 2]), x[1, 2]), |
| 401 | + (@varname(x[2, 2]), x[2, 2]), |
| 402 | + ]) |
| 403 | + end |
| 404 | + |
| 405 | + @testset "NamedTuple" begin |
| 406 | + x = (a=1.0, b=[2.0, 3.0]) |
| 407 | + @test Set(varname_leaves(@varname(x), x)) == |
| 408 | + Set([@varname(x.a), @varname(x.b[1]), @varname(x.b[2])]) |
| 409 | + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ |
| 410 | + (@varname(x.a), x.a), (@varname(x.b[1]), x.b[1]), (@varname(x.b[2]), x.b[2]) |
| 411 | + ]) |
| 412 | + end |
| 413 | + |
| 414 | + @testset "Cholesky" begin |
| 415 | + x = cholesky([1.0 0.5; 0.5 1.0]) |
| 416 | + @test Set(varname_leaves(@varname(x), x)) == |
| 417 | + Set([@varname(x.U[1, 1]), @varname(x.U[1, 2]), @varname(x.U[2, 2])]) |
| 418 | + @test Set(collect(varname_and_value_leaves(@varname(x), x))) == Set([ |
| 419 | + (@varname(x.U[1, 1]), x.U[1, 1]), |
| 420 | + (@varname(x.U[1, 2]), x.U[1, 2]), |
| 421 | + (@varname(x.U[2, 2]), x.U[2, 2]), |
| 422 | + ]) |
| 423 | + end |
| 424 | + end |
345 | 425 | end
|
0 commit comments