Skip to content

Commit 1f1bbd0

Browse files
committed
hash, eq for PFState; tests
1 parent ebeee8e commit 1f1bbd0

File tree

2 files changed

+46
-5
lines changed

2 files changed

+46
-5
lines changed

src/inference/particle_filter.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,28 @@ mutable struct ParticleFilterState{U}
2424
end
2525

2626
function Base.copy(state::ParticleFilterState{U}) where U
27-
ParticleFilterState{U}(
28-
copy(state.traces),
29-
copy(state.new_traces),
30-
copy(state.log_weights),
27+
return ParticleFilterState{U}(
28+
Base.copy(state.traces),
29+
Base.copy(state.new_traces),
30+
Base.copy(state.log_weights),
3131
state.log_ml_est,
32-
copy(state.parents)
32+
Base.copy(state.parents)
3333
)
3434
end
35+
function Base.:(==)(a::ParticleFilterState{U}, b::ParticleFilterState{V}) where {U, V}
36+
return U == V &&
37+
a.traces == b.traces &&
38+
a.log_weights == b.log_weights &&
39+
a.log_ml_est == b.log_ml_est &&
40+
a.parents == b.parents
41+
end
42+
function Base.hash(state::ParticleFilterState{U}, h::UInt) where {U}
43+
return hash(U,
44+
hash(state.traces,
45+
hash(state.log_weights .+ (state.log_ml_est - log(length(state.log_weights))),
46+
hash(state.parents,
47+
h))))
48+
end
3549

3650
"""
3751
traces = get_traces(state::ParticleFilterState)

test/inference/particle_filter.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,31 @@ end
169169
@test isapprox(expected_log_ml, actual_log_ml_est, atol=0.02)
170170
end
171171

172+
@testset "PF state" begin
173+
@gen function _foo()
174+
x ~ normal(0, 1)
175+
end
176+
st = Gen.ParticleFilterState{Gen.Trace}(
177+
[simulate(_foo, ()) for _=1:10],
178+
Vector{Gen.Trace}(undef, 10),
179+
[0. for _=1:10],
180+
0.,
181+
collect(1:10)
182+
)
183+
st2 = copy(st)
184+
@test st == st2
185+
@test hash(st) == hash(st2)
186+
st2.log_weights[1] = 1.
187+
@test st.log_weights[1] == 0.
188+
@test st != st2
189+
@test hash(st) != hash(st2)
190+
191+
# test that the other fields are independent copies too:
192+
st.traces[1] = simulate(_foo, ())
193+
@test st.traces != st2.traces
194+
st.log_ml_est = 1.
195+
@test st.log_ml_est == 1.
196+
st.parents[1] = 5
197+
@test st2.parents[1] == 1
172198
end
199+
end

0 commit comments

Comments
 (0)