Skip to content

Commit e902811

Browse files
committed
fix test errors
1 parent 58162a3 commit e902811

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

src/AdvancedMH.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,13 @@ function AbstractMCMC.getparams(t::Transition)
146146
return t.params
147147
end
148148

149-
function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, t::Transition, params)
150-
t = BangBang.setproperty!!(t, :params, params)
151-
return BangBang.setproperty!!(t, :lp, LogDensityProblems.logdensity(model.logdensity, params))
149+
# TODO (sunxd): remove `DensityModel` in favor of `AbstractMCMC.LogDensityModel`
150+
function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::Transition, params)
151+
return Transition(
152+
params,
153+
logdensity(model, params),
154+
t.accepted
155+
)
152156
end
153157

154158
# Include inference methods.

src/MALA.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ function AbstractMCMC.getparams(t::GradientTransition)
2424
return t.params
2525
end
2626

27-
function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, t::GradientTransition, params)
28-
return AdvancedMH.GradientTransition(
27+
function AbstractMCMC.setparams!!(model::DensityModelOrLogDensityModel, t::GradientTransition, params)
28+
lp, gradient = logdensity_and_gradient(model, params)
29+
return GradientTransition(
2930
params,
30-
AdvancedMH.logdensity_and_gradient(model.logdensity, params)...,
31+
lp,
32+
gradient,
3133
t.accepted
3234
)
3335
end

test/runtests.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,18 @@ include("util.jl")
3838
t1, _ = AbstractMCMC.step(Random.default_rng(), model, StaticMH([Normal(0, 1), Normal(0, 1)]))
3939
t2, _ = AbstractMCMC.step(Random.default_rng(), model, MALA(x -> MvNormal(x, I)); initial_params=ones(2))
4040
for t in [t1, t2]
41-
@test AbstractMCMC.getparams(t) == t.params
42-
@test AbstractMCMC.setparams!!(model, t, AbstractMCMC.getparams(t)) == t
43-
t_replaced = AbstractMCMC.setparams!!(model, t, (μ=1.0, σ=2.0))
44-
@test t_replaced.params ===1.0, σ=2.0)
41+
@test AbstractMCMC.getparams(model, t) == t.params
42+
43+
new_transition = AbstractMCMC.setparams!!(model, t, AbstractMCMC.getparams(model, t))
44+
@test new_transition.lp == t.lp
45+
@test new_transition.accepted == t.accepted
46+
@test new_transition.params == t.params
47+
if hasfield(typeof(t), :gradient)
48+
@test new_transition.gradient == t.gradient
49+
end
50+
51+
t_replaced = AbstractMCMC.setparams!!(model, t, [1.0, 2.0])
52+
@test t_replaced.params == [1.0, 2.0]
4553
end
4654
end
4755

0 commit comments

Comments
 (0)