@@ -483,23 +483,34 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
483
483
[@varname (s), @varname (m), @varname (x[2 ])],
484
484
[@varname (s), @varname (x[1 ]), @varname (x[2 ])],
485
485
[@varname (m), @varname (x[1 ]), @varname (x[2 ])],
486
- [@varname (s), @varname (m), @varname (x[1 ]), @varname (x[2 ])],
487
486
]
488
487
489
- # `SimpleaVarInfo` only supports subsetting using the varnames as they appear
488
+ # Patterns requiring `subsumes`.
489
+ vns_supported_with_subsumes = [
490
+ [@varname (s), @varname (x)] => [@varname (s), @varname (x[1 ]), @varname (x[2 ])],
491
+ [@varname (m), @varname (x)] => [@varname (m), @varname (x[1 ]), @varname (x[2 ])],
492
+ [@varname (s), @varname (m), @varname (x)] =>
493
+ [@varname (s), @varname (m), @varname (x[1 ]), @varname (x[2 ])],
494
+ ]
495
+
496
+ # `SimpleVarInfo` only supports subsetting using the varnames as they appear
490
497
# in the model.
491
498
vns_supported_simple = filter (∈ (vns), vns_supported_standard)
492
499
493
- @testset " $(short_varinfo_name (varinfo)) " for varinfo in varinfos_standard
500
+ @testset " $(short_varinfo_name (varinfo)) " for varinfo in varinfos
494
501
# All variables.
495
502
check_varinfo_keys (varinfo, vns)
496
503
497
504
# Added a `convert` to make the naming of the testsets a bit more readable.
498
- vns_supported = if varinfo isa DynamicPPL. SimpleOrThreadSafeSimple
499
- vns_supported_simple
500
- else
501
- vns_supported_standard
502
- end
505
+ # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames,
506
+ # # i.e. `VarName{sym}()` without any indexing, etc.
507
+ vns_supported =
508
+ if varinfo isa DynamicPPL. SimpleOrThreadSafeSimple &&
509
+ values_as (varinfo) isa NamedTuple
510
+ vns_supported_simple
511
+ else
512
+ vns_supported_standard
513
+ end
503
514
@testset " $(convert (Vector{VarName}, vns_subset)) " for vns_subset in
504
515
vns_supported
505
516
varinfo_subset = subset (varinfo, vns_subset)
@@ -516,6 +527,24 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
516
527
# Values should be the same.
517
528
@test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns]
518
529
end
530
+
531
+ @testset " $(convert (Vector{VarName}, vns_subset)) " for (
532
+ vns_subset, vns_target
533
+ ) in vns_supported_with_subsumes
534
+ varinfo_subset = subset (varinfo, vns_subset)
535
+ # Should now only contain the variables in `vns_subset`.
536
+ check_varinfo_keys (varinfo_subset, vns_target)
537
+ # Values should be the same.
538
+ @test [varinfo_subset[vn] for vn in vns_target] == [varinfo[vn] for vn in vns_target]
539
+
540
+ # `merge` with the original.
541
+ varinfo_merged = merge (varinfo, varinfo_subset)
542
+ vns_merged = keys (varinfo_merged)
543
+ # Should be equivalent.
544
+ check_varinfo_keys (varinfo_merged, vns)
545
+ # Values should be the same.
546
+ @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns]
547
+ end
519
548
end
520
549
521
550
# For certain varinfos we should have errors.
@@ -526,15 +555,6 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
526
555
varinfo, [@varname (s), @varname (m), @varname (x[1 ])]
527
556
)
528
557
end
529
- # `SimpleVarInfo{<:AbstractDict}` can only handle varnames as they appear in the model.
530
- varinfo = varinfos[findfirst (
531
- Base. Fix2 (isa, SimpleVarInfo{<: AbstractDict }), varinfos
532
- )]
533
- @testset " $(short_varinfo_name (varinfo)) : failure cases" begin
534
- @test_throws ArgumentError subset (
535
- varinfo, [@varname (s), @varname (m), @varname (x)]
536
- )
537
- end
538
558
end
539
559
540
560
@testset " merge" begin
0 commit comments