@@ -17,7 +17,11 @@ using ADTypes: AutoForwardDiff
17
17
for vn in [@varname (x), :x ]
18
18
for getlogprob in [DynamicPPL. getlogprior, DynamicPPL. getlogjoint]
19
19
marginalized = marginalize (
20
- model, [vn], vi, getlogprob; hess_adtype= AutoForwardDiff ()
20
+ model,
21
+ [vn];
22
+ varinfo= vi,
23
+ getlogprob= getlogprob,
24
+ hess_adtype= AutoForwardDiff (),
21
25
)
22
26
for y in range (- 5 , 5 ; length= 100 )
23
27
@test marginalized ([y]) ≈ logpdf (Normal (0 , 1 ), y) atol = 1e-5
@@ -36,27 +40,29 @@ using ADTypes: AutoForwardDiff
36
40
vi_linked = DynamicPPL. link (vi_unlinked, model)
37
41
38
42
@testset " unlinked VarInfo" begin
39
- mx = marginalize (model, [@varname (x)], vi_unlinked)
43
+ mx = marginalize (model, [@varname (x)]; varinfo = vi_unlinked)
40
44
for x in range (0.01 , 0.99 ; length= 10 )
41
45
@test mx ([x]) ≈ logpdf (Beta (2 , 2 ), x)
42
46
end
43
47
# generally when marginalising Beta it doesn't go to zero
44
- my = marginalize (model, [@varname (y)], vi_unlinked)
48
+ # https://github.com/TuringLang/DynamicPPL.jl/pull/1036#discussion_r2349388067
49
+ my = marginalize (model, [@varname (y)]; varinfo= vi_unlinked)
45
50
diff = my ([0.0 ]) - logpdf (Normal (), 0.0 )
46
51
for x in range (- 5 , 5 ; length= 10 )
47
52
@test my ([x]) ≈ logpdf (Normal (), x) + diff
48
53
end
49
54
end
50
55
51
56
@testset " linked VarInfo" begin
52
- mx = marginalize (model, [@varname (x)], vi_linked)
57
+ mx = marginalize (model, [@varname (x)]; varinfo = vi_linked)
53
58
binv = Bijectors. inverse (Bijectors. bijector (Beta (2 , 2 )))
54
59
for y_linked in range (- 5 , 5 ; length= 10 )
55
60
y_unlinked = binv (y_linked)
56
61
@test mx ([y_linked]) ≈ logpdf (Beta (2 , 2 ), y_unlinked)
57
62
end
58
63
# generally when marginalising Beta it doesn't go to zero
59
- my = marginalize (model, [@varname (y)], vi_linked)
64
+ # https://github.com/TuringLang/DynamicPPL.jl/pull/1036#discussion_r2349388067
65
+ my = marginalize (model, [@varname (y)]; varinfo= vi_linked)
60
66
diff = my ([0.0 ]) - logpdf (Normal (), 0.0 )
61
67
for x in range (- 5 , 5 ; length= 10 )
62
68
@test my ([x]) ≈ logpdf (Normal (), x) + diff
0 commit comments