Skip to content

Commit 68afaad

Browse files
committed
Move link!!/invlink!! tests back into DynamicPPL
1 parent 356a787 commit 68afaad

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

test/varinfo.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,57 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
387387
end
388388
end
389389

390+
@testset "link!! and invlink!!" begin
391+
@model gdemo(x, y) = begin
392+
s ~ InverseGamma(2, 3)
393+
m ~ Uniform(0, 2)
394+
x ~ Normal(m, sqrt(s))
395+
y ~ Normal(m, sqrt(s))
396+
end
397+
model = gdemo(1.0, 2.0)
398+
399+
# Check that instantiating the model does not perform linking
400+
vi = VarInfo()
401+
meta = vi.metadata
402+
model(vi, SampleFromUniform())
403+
@test all(x -> !istrans(vi, x), meta.vns)
404+
405+
# Check that linking and invlinking set the `trans` flag accordingly
406+
v = copy(meta.vals)
407+
link!!(vi, model)
408+
@test all(x -> istrans(vi, x), meta.vns)
409+
invlink!!(vi, model)
410+
@test all(x -> !istrans(vi, x), meta.vns)
411+
@test meta.vals v atol = 1e-10
412+
413+
# Check that linking and invlinking preserves the values
414+
vi = TypedVarInfo(vi)
415+
meta = vi.metadata
416+
@test all(x -> !istrans(vi, x), meta.s.vns)
417+
@test all(x -> !istrans(vi, x), meta.m.vns)
418+
v_s = copy(meta.s.vals)
419+
v_m = copy(meta.m.vals)
420+
link!!(vi, model)
421+
@test all(x -> istrans(vi, x), meta.s.vns)
422+
@test all(x -> istrans(vi, x), meta.m.vns)
423+
invlink!!(vi, model)
424+
@test all(x -> !istrans(vi, x), meta.s.vns)
425+
@test all(x -> !istrans(vi, x), meta.m.vns)
426+
@test meta.s.vals v_s atol = 1e-10
427+
@test meta.m.vals v_m atol = 1e-10
428+
429+
# Transform only one variable (`s`) but not the others (`m`)
430+
spl = DynamicPPL.Sampler(MySAlg(), model)
431+
link!!(vi, spl, model)
432+
@test all(x -> istrans(vi, x), meta.s.vns)
433+
@test all(x -> !istrans(vi, x), meta.m.vns)
434+
invlink!!(vi, spl, model)
435+
@test all(x -> !istrans(vi, x), meta.s.vns)
436+
@test all(x -> !istrans(vi, x), meta.m.vns)
437+
@test meta.s.vals v_s atol = 1e-10
438+
@test meta.m.vals v_m atol = 1e-10
439+
end
440+
390441
@testset "istrans" begin
391442
@model demo_constrained() = x ~ truncated(Normal(), 0, Inf)
392443
model = demo_constrained()

0 commit comments

Comments
 (0)