Skip to content

Commit a2fa27b

Browse files
committed
fix tests
1 parent cabc01e commit a2fa27b

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

test/varinfo.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using DynamicPPL: Selector, reconstruct, invlink, CACHERESET,
66
set_flag!, unset_flag!, VarInfo, TypedVarInfo,
77
getlogp, setlogp!, resetlogp!, acclogp!, vectorize,
88
setorder!, updategid!
9-
using DynamicPPL
9+
using DynamicPPL, LinearAlgebra
1010
using Distributions
1111
using ForwardDiff: Dual
1212
using Test
@@ -167,32 +167,32 @@ include(dir*"/test/test_utils/AllUtils.jl")
167167
meta = vi.metadata
168168
model(vi, SampleFromUniform())
169169

170-
@test all(x -> ~istrans(vi, x), meta.vns)
170+
@test all(x -> istrans(vi, x), meta.vns)
171171
alg = HMC(0.1, 5)
172172
spl = Sampler(alg, model)
173173
v = copy(meta.vals)
174-
link!(vi, spl)
175-
@test all(x -> istrans(vi, x), meta.vns)
176174
invlink!(vi, spl)
177175
@test all(x -> ~istrans(vi, x), meta.vns)
178-
@test meta.vals == v
176+
link!(vi, spl)
177+
@test all(x -> istrans(vi, x), meta.vns)
178+
@test norm(meta.vals - v) <= 1e-6
179179

180180
vi = TypedVarInfo(vi)
181181
meta = vi.metadata
182182
alg = HMC(0.1, 5)
183183
spl = Sampler(alg, model)
184-
@test all(x -> ~istrans(vi, x), meta.s.vns)
185-
@test all(x -> ~istrans(vi, x), meta.m.vns)
186-
v_s = copy(meta.s.vals)
187-
v_m = copy(meta.m.vals)
188-
link!(vi, spl)
189184
@test all(x -> istrans(vi, x), meta.s.vns)
190185
@test all(x -> istrans(vi, x), meta.m.vns)
186+
v_s = copy(meta.s.vals)
187+
v_m = copy(meta.m.vals)
191188
invlink!(vi, spl)
192189
@test all(x -> ~istrans(vi, x), meta.s.vns)
193190
@test all(x -> ~istrans(vi, x), meta.m.vns)
194-
@test meta.s.vals == v_s
195-
@test meta.m.vals == v_m
191+
link!(vi, spl)
192+
@test all(x -> istrans(vi, x), meta.s.vns)
193+
@test all(x -> istrans(vi, x), meta.m.vns)
194+
@test norm(meta.s.vals - v_s) <= 1e-6
195+
@test norm(meta.m.vals - v_m) <= 1e-6
196196
end
197197
@testset "setgid!" begin
198198
vi = VarInfo()

0 commit comments

Comments
 (0)