File tree Expand file tree Collapse file tree 2 files changed +46
-5
lines changed Expand file tree Collapse file tree 2 files changed +46
-5
lines changed Original file line number Diff line number Diff line change @@ -24,14 +24,28 @@ mutable struct ParticleFilterState{U}
2424end
2525
2626function Base. copy (state:: ParticleFilterState{U} ) where U
27- ParticleFilterState {U} (
28- copy (state. traces),
29- copy (state. new_traces),
30- copy (state. log_weights),
27+ return ParticleFilterState {U} (
28+ Base . copy (state. traces),
29+ Base . copy (state. new_traces),
30+ Base . copy (state. log_weights),
3131 state. log_ml_est,
32- copy (state. parents)
32+ Base . copy (state. parents)
3333 )
3434end
35+ function Base.:(== )(a:: ParticleFilterState{U} , b:: ParticleFilterState{V} ) where {U, V}
36+ return U == V &&
37+ a. traces == b. traces &&
38+ a. log_weights == b. log_weights &&
39+ a. log_ml_est == b. log_ml_est &&
40+ a. parents == b. parents
41+ end
42+ function Base. hash (state:: ParticleFilterState{U} , h:: UInt ) where {U}
43+ return hash (U,
44+ hash (state. traces,
45+ hash (state. log_weights .+ (state. log_ml_est - log (length (state. log_weights))),
46+ hash (state. parents,
47+ h))))
48+ end
3549
3650"""
3751 traces = get_traces(state::ParticleFilterState)
Original file line number Diff line number Diff line change 169169 @test isapprox (expected_log_ml, actual_log_ml_est, atol= 0.02 )
170170end
171171
172+ @testset " PF state" begin
173+ @gen function _foo ()
174+ x ~ normal (0 , 1 )
175+ end
176+ st = Gen. ParticleFilterState {Gen.Trace} (
177+ [simulate (_foo, ()) for _= 1 : 10 ],
178+ Vector {Gen.Trace} (undef, 10 ),
179+ [0. for _= 1 : 10 ],
180+ 0. ,
181+ collect (1 : 10 )
182+ )
183+ st2 = copy (st)
184+ @test st == st2
185+ @test hash (st) == hash (st2)
186+ st2. log_weights[1 ] = 1.
187+ @test st. log_weights[1 ] == 0.
188+ @test st != st2
189+ @test hash (st) != hash (st2)
190+
191+ # test that the other fields are independent copies too:
192+ st. traces[1 ] = simulate (_foo, ())
193+ @test st. traces != st2. traces
194+ st. log_ml_est = 1.
195+ @test st. log_ml_est == 1.
196+ st. parents[1 ] = 5
197+ @test st2. parents[1 ] == 1
172198end
199+ end
You can’t perform that action at this time.
0 commit comments