Skip to content

Commit 0dee0f8

Browse files
JaimeRZPdevmotionyebai
authored
Transition (#2026)
* 1st commit * should be working * connector to old struct * different type * bug * bug * Update src/inference/Inference.jl David's suggestion Co-authored-by: David Widmann <[email protected]> * Update src/inference/Inference.jl Co-authored-by: David Widmann <[email protected]> * remove reference to gibbs transitions int tests * Update src/inference/Inference.jl --------- Co-authored-by: David Widmann <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent 3e8d97f commit 0dee0f8

File tree

4 files changed

+25
-44
lines changed

4 files changed

+25
-44
lines changed

src/inference/Inference.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,28 @@ end
105105
# Default Transition #
106106
######################
107107

108-
struct Transition{T, F<:AbstractFloat}
109-
θ :: T
110-
lp :: F
108+
struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}}
109+
θ :: T
110+
lp :: F # TODO: merge `lp` with `stat`
111+
stat :: S
111112
end
112113

113-
function Transition(vi::AbstractVarInfo, nt::NamedTuple=NamedTuple())
114-
theta = merge(tonamedtuple(vi), nt)
114+
Transition(θ, lp) = Transition(θ, lp, nothing)
115+
116+
function Transition(vi::AbstractVarInfo; nt::NamedTuple=NamedTuple())
117+
θ = merge(tonamedtuple(vi), nt)
115118
lp = getlogp(vi)
116-
return Transition{typeof(theta), typeof(lp)}(theta, lp)
119+
return Transition(θ, lp, nothing)
117120
end
118121

119-
metadata(t::Transition) = (lp = t.lp,)
122+
function metadata(t::Transition)
123+
stat = t.stat
124+
if stat === nothing
125+
return (lp = t.lp,)
126+
else
127+
return merge((lp = t.lp,), stat)
128+
end
129+
end
120130

121131
DynamicPPL.getlogp(t::Transition) = t.lp
122132

src/inference/gibbs.jl

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,23 +69,6 @@ struct GibbsState{V<:VarInfo,S<:Tuple{Vararg{Sampler}},T}
6969
states::T
7070
end
7171

72-
struct GibbsTransition{T,F}
73-
"The parameters for any given sample."
74-
θ::T
75-
"The joint log probability for the sample's parameters."
76-
lp::F
77-
end
78-
79-
function GibbsTransition(vi::AbstractVarInfo)
80-
theta = tonamedtuple(vi)
81-
lp = getlogp(vi)
82-
return GibbsTransition(theta, lp)
83-
end
84-
85-
metadata(t::GibbsTransition) = (lp = t.lp,)
86-
87-
DynamicPPL.getlogp(t::GibbsTransition) = t.lp
88-
8972
# extract varinfo object from state
9073
"""
9174
gibbs_varinfo(model, sampler, state)
@@ -213,7 +196,7 @@ function DynamicPPL.initialstep(
213196
end
214197

215198
# Compute initial transition and state.
216-
transition = GibbsTransition(vi)
199+
transition = Transition(vi)
217200
state = GibbsState(vi, samplers, states)
218201

219202
return transition, state
@@ -248,5 +231,5 @@ function AbstractMCMC.step(
248231
return newstate
249232
end
250233

251-
return GibbsTransition(vi), GibbsState(vi, samplers, states)
234+
return Transition(vi), GibbsState(vi, samplers, states)
252235
end

src/inference/hmc.jl

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,12 @@ end
2121
# Hamiltonian Transition #
2222
##########################
2323

24-
struct HMCTransition{T,NT<:NamedTuple,F<:AbstractFloat}
25-
θ::T
26-
lp::F
27-
stat::NT
28-
end
29-
30-
function HMCTransition(vi::AbstractVarInfo, t::AHMC.Transition)
24+
function Transition(vi::AbstractVarInfo, t::AHMC.Transition)
3125
theta = tonamedtuple(vi)
3226
lp = getlogp(vi)
33-
return HMCTransition(theta, lp, t.stat)
27+
return Transition(theta, lp, t.stat)
3428
end
3529

36-
function metadata(t::HMCTransition)
37-
return merge((lp = t.lp,), t.stat)
38-
end
39-
40-
DynamicPPL.getlogp(t::HMCTransition) = t.lp
41-
4230
###
4331
### Hamiltonian Monte Carlo samplers.
4432
###
@@ -230,7 +218,7 @@ function DynamicPPL.initialstep(
230218
vi = setlogp!!(vi, log_density_old)
231219
end
232220

233-
transition = HMCTransition(vi, t)
221+
transition = Transition(vi, t)
234222
state = HMCState(vi, 1, kernel, hamiltonian, t.z, adaptor)
235223

236224
return transition, state
@@ -270,7 +258,7 @@ function AbstractMCMC.step(
270258
end
271259

272260
# Compute next transition and state.
273-
transition = HMCTransition(vi, t)
261+
transition = Transition(vi, t)
274262
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
275263

276264
return transition, newstate

test/inference/gibbs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@
7474
::Type{MCMCChains.Chains};
7575
kwargs...
7676
)
77-
samples isa Vector{<:Inference.GibbsTransition} ||
77+
samples isa Vector{<:Inference.Transition} ||
7878
error("incorrect transitions")
7979
return
8080
end
8181

8282
function callback(rng, model, sampler, sample, state, i; kwargs...)
83-
sample isa Inference.GibbsTransition || error("incorrect sample")
83+
sample isa Inference.Transition || error("incorrect sample")
8484
return
8585
end
8686

0 commit comments

Comments
 (0)