Skip to content

Commit dc6001c

Browse files
committed
updates
1 parent 8962d40 commit dc6001c

File tree

3 files changed

+12
-13
lines changed

3 files changed

+12
-13
lines changed

src/AbstractMCMC.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ struct MCMCSerial <: AbstractMCMCEnsemble end
8282

8383
function condition end
8484

85+
function logdensity_and_state end
86+
8587
include("samplingstats.jl")
8688
include("logging.jl")
8789
include("interface.jl")

src/gibbs.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -203,16 +203,7 @@ function AbstractMCMC.step(
203203
logdensity_model.logdensity, conditioning_variables_values
204204
)
205205

206-
# recompute the logdensity stored in the mcmc state, because the values might have been updated in other sub-problems
207-
updated_log_prob = LogDensityProblems.logdensity(cond_logdensity, sub_state)
208-
209-
if !hasproperty(sub_state, :logp)
210-
error(
211-
"$(typeof(sub_state)) does not have a `:logp` field, which is required by Gibbs sampling",
212-
)
213-
end
214-
sub_state = BangBang.setproperty!!(sub_state, :logp, updated_log_prob)
215-
206+
_, sub_state = AbstractMCMC.logdensity_and_state(cond_logdensity, sub_state)
216207
sub_state = last(
217208
AbstractMCMC.step(
218209
rng,

test/gibbs_example/mh.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,15 @@ struct MHTransition{T}
99
params::Vector{T}
1010
end
1111

12-
function AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state::MHState)
13-
# recompute the logdensity, instead of using the one stored in the state
14-
return AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params)
12+
function AbstractMCMC.logdensity_and_state(
13+
logdensity_function, state::MHState; recompute_logp::Bool=true
14+
)
15+
if recompute_logp
16+
logp, substate = AbstractMCMC.LogDensityProblems.logdensity(logdensity_function, state.params)
17+
return logp, MHState(substate.params, logp)
18+
else
19+
return state.logp, state
20+
end
1521
end
1622

1723
function Base.vec(state::MHState)

0 commit comments

Comments
 (0)