@@ -38,10 +38,18 @@ include("util.jl")
3838 t1, _ = AbstractMCMC. step (Random. default_rng (), model, StaticMH ([Normal (0 , 1 ), Normal (0 , 1 )]))
3939 t2, _ = AbstractMCMC. step (Random. default_rng (), model, MALA (x -> MvNormal (x, I)); initial_params= ones (2 ))
4040 for t in [t1, t2]
41- @test AbstractMCMC. getparams (t) == t. params
42- @test AbstractMCMC. setparams!! (model, t, AbstractMCMC. getparams (t)) == t
43- t_replaced = AbstractMCMC. setparams!! (model, t, (μ= 1.0 , σ= 2.0 ))
44- @test t_replaced. params == (μ= 1.0 , σ= 2.0 )
41+ @test AbstractMCMC. getparams (model, t) == t. params
42+
43+ new_transition = AbstractMCMC. setparams!! (model, t, AbstractMCMC. getparams (model, t))
44+ @test new_transition. lp == t. lp
45+ @test new_transition. accepted == t. accepted
46+ @test new_transition. params == t. params
47+ if hasfield (typeof (t), :gradient )
48+ @test new_transition. gradient == t. gradient
49+ end
50+
51+ t_replaced = AbstractMCMC. setparams!! (model, t, [1.0 , 2.0 ])
52+ @test t_replaced. params == [1.0 , 2.0 ]
4553 end
4654 end
4755
0 commit comments