Skip to content

Commit 3287746

Browse files
authored
Merge pull request #509 from probcomp/20230620pfstate-copy
add Base.copy(::ParticleFilterState)
2 parents 262f96f + 027177b commit 3287746

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

src/inference/particle_filter.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,28 @@ mutable struct ParticleFilterState{U}
2323
parents::Vector{Int}
2424
end
2525

26+
function Base.copy(state::ParticleFilterState{U}) where U
27+
return ParticleFilterState{U}(
28+
Base.copy(state.traces),
29+
Base.copy(state.new_traces),
30+
Base.copy(state.log_weights),
31+
state.log_ml_est,
32+
Base.copy(state.parents)
33+
)
34+
end
35+
function Base.:(==)(a::ParticleFilterState{U}, b::ParticleFilterState{V}) where {U, V}
36+
return U == V &&
37+
a.traces == b.traces &&
38+
a.log_weights == b.log_weights &&
39+
a.log_ml_est == b.log_ml_est &&
40+
a.parents == b.parents
41+
end
42+
function Base.hash(state::ParticleFilterState{U}, h::UInt) where {U}
43+
return hash(U, hash(state.traces,
44+
hash(state.log_weights,
45+
hash(state.log_ml_est, hash(state.parents, h)))))
46+
end
47+
2648
"""
2749
traces = get_traces(state::ParticleFilterState)
2850

test/inference/particle_filter.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,31 @@ end
169169
@test isapprox(expected_log_ml, actual_log_ml_est, atol=0.02)
170170
end
171171

172+
@testset "PF state" begin
173+
@gen function _foo()
174+
x ~ normal(0, 1)
175+
end
176+
st = Gen.ParticleFilterState{Gen.Trace}(
177+
[simulate(_foo, ()) for _=1:10],
178+
Vector{Gen.Trace}(undef, 10),
179+
[0. for _=1:10],
180+
0.,
181+
collect(1:10)
182+
)
183+
st2 = copy(st)
184+
@test st == st2
185+
@test hash(st) == hash(st2)
186+
st2.log_weights[1] = 1.
187+
@test st.log_weights[1] == 0.
188+
@test st != st2
189+
@test hash(st) != hash(st2)
190+
191+
# test that the other fields are independent copies too:
192+
st.traces[1] = simulate(_foo, ())
193+
@test st.traces != st2.traces
194+
st.log_ml_est = 1.
195+
@test st.log_ml_est == 1.
196+
st.parents[1] = 5
197+
@test st2.parents[1] == 1
172198
end
199+
end

0 commit comments

Comments
 (0)