Skip to content

Commit 55be793

Browse files
committed
Turn NumProduceAccumulator into VariableOrderAccumulator
1 parent d4ef1f2 commit 55be793

File tree

10 files changed

+134
-118
lines changed

10 files changed

+134
-118
lines changed

docs/src/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ get_num_produce
331331
set_num_produce!!
332332
increment_num_produce!!
333333
reset_num_produce!!
334-
setorder!
334+
setorder!!
335335
set_retained_vns_del!
336336
```
337337

@@ -358,7 +358,7 @@ DynamicPPL provides the following default accumulators.
358358
```@docs
359359
LogPriorAccumulator
360360
LogLikelihoodAccumulator
361-
NumProduceAccumulator
361+
VariableOrderAccumulator
362362
```
363363

364364
### Common API

src/DynamicPPL.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ export AbstractVarInfo,
5050
AbstractAccumulator,
5151
LogLikelihoodAccumulator,
5252
LogPriorAccumulator,
53-
NumProduceAccumulator,
53+
VariableOrderAccumulator,
5454
push!!,
5555
empty!!,
5656
subset,
@@ -73,7 +73,7 @@ export AbstractVarInfo,
7373
is_flagged,
7474
set_flag!,
7575
unset_flag!,
76-
setorder!,
76+
setorder!!,
7777
istrans,
7878
link,
7979
link!!,

src/abstract_varinfo.jl

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,24 @@ function resetlogp!!(vi::AbstractVarInfo)
374374
return vi
375375
end
376376

377+
"""
378+
setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
379+
380+
Set the `order` of `vn` in `vi` to `index`, where `order` is the number of `observe
381+
statements run before sampling `vn`.
382+
"""
383+
function setorder!!(vi::AbstractVarInfo, vn::VarName, index::Integer)
384+
return map_accumulator!!(acc -> (acc.order[vn] = index; acc), vi, Val(:VariableOrder))
385+
end
386+
387+
"""
388+
getorder(vi::VarInfo, vn::VarName)
389+
390+
Get the `order` of `vn` in `vi`, where `order` is the number of `observe` statements
391+
run before sampling `vn`.
392+
"""
393+
getorder(vi::AbstractVarInfo, vn::VarName) = getacc(vi, Val(:VariableOrder)).order[vn]
394+
377395
# Variables and their realizations.
378396
@doc """
379397
keys(vi::AbstractVarInfo)
@@ -972,29 +990,37 @@ end
972990
973991
Return the `num_produce` of `vi`.
974992
"""
975-
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:NumProduce)).num
993+
get_num_produce(vi::AbstractVarInfo) = getacc(vi, Val(:VariableOrder)).num_produce
976994

977995
"""
978996
set_num_produce!!(vi::AbstractVarInfo, n::Int)
979997
980998
Set the `num_produce` field of `vi` to `n`.
981999
"""
982-
set_num_produce!!(vi::AbstractVarInfo, n::Int) = setacc!!(vi, NumProduceAccumulator(n))
1000+
function set_num_produce!!(vi::AbstractVarInfo, n::Integer)
1001+
if hasacc(vi, Val(:VariableOrder))
1002+
acc = getacc(vi, Val(:VariableOrder))
1003+
acc = VariableOrderAccumulator(n, acc.order)
1004+
else
1005+
acc = VariableOrderAccumulator(n)
1006+
end
1007+
return setacc!!(vi, acc)
1008+
end
9831009

9841010
"""
9851011
increment_num_produce!!(vi::AbstractVarInfo)
9861012
9871013
Add 1 to `num_produce` in `vi`.
9881014
"""
9891015
increment_num_produce!!(vi::AbstractVarInfo) =
990-
map_accumulator!!(increment, vi, Val(:NumProduce))
1016+
map_accumulator!!(increment, vi, Val(:VariableOrder))
9911017

9921018
"""
9931019
reset_num_produce!!(vi::AbstractVarInfo)
9941020
9951021
Reset the value of `num_produce` in `vi` to 0.
9961022
"""
997-
reset_num_produce!!(vi::AbstractVarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce))
1023+
reset_num_produce!!(vi::AbstractVarInfo) = set_num_produce!!(vi, zero(get_num_produce(vi)))
9981024

9991025
"""
10001026
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist])

src/context_implementations.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,6 @@ function assume(
184184
f = to_maybe_linked_internal_transform(vi, vn, dist)
185185
# TODO(mhauru) This should probably be call a function called setindex_internal!
186186
vi = BangBang.setindex!!(vi, f(r), vn)
187-
setorder!(vi, vn, get_num_produce(vi))
188187
else
189188
# Otherwise we just extract it.
190189
r = vi[vn, dist]

src/default_accumulators.jl

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,52 +41,62 @@ LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)
4141
LogLikelihoodAccumulator() = LogLikelihoodAccumulator{LogProbType}()
4242

4343
"""
44-
NumProduceAccumulator{T} <: AbstractAccumulator
44+
VariableOrderAccumulator{T} <: AbstractAccumulator
4545
46-
An accumulator that tracks the number of observations during model execution.
46+
An accumulator that tracks the order of variables in a `VarInfo`.
47+
48+
This doesn't track the full ordering, but rather how many observations have taken place
49+
before the assume statement for each variable. This is needed for particle methods, where
50+
the model is segmented into parts by each observation, and we need to know which part each
51+
assume statement is in.
4752
4853
# Fields
4954
$(TYPEDFIELDS)
5055
"""
51-
struct NumProduceAccumulator{T<:Integer} <: AbstractAccumulator
56+
struct VariableOrderAccumulator{Eltype<:Integer,VNType<:VarName} <: AbstractAccumulator
5257
"the number of observations"
53-
num::T
58+
num_produce::Eltype
59+
"mapping of variable names to their order in the model"
60+
order::OrderedDict{VNType, Eltype}
5461
end
5562

5663
"""
57-
NumProduceAccumulator{T<:Integer}()
64+
VariableOrderAccumulator{T<:Integer}(n=zero(T))
5865
59-
Create a new `NumProduceAccumulator` accumulator with the number of observations initialized to zero.
66+
Create a new `VariableOrderAccumulator` accumulator with the number of observations set to n
6067
"""
61-
NumProduceAccumulator{T}() where {T<:Integer} = NumProduceAccumulator(zero(T))
62-
NumProduceAccumulator() = NumProduceAccumulator{Int}()
68+
VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} = VariableOrderAccumulator(convert(T, n), OrderedDict{VarName, T}())
69+
VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n)
70+
VariableOrderAccumulator() = VariableOrderAccumulator{Int}()
6371

6472
function Base.show(io::IO, acc::LogPriorAccumulator)
6573
return print(io, "LogPriorAccumulator($(repr(acc.logp)))")
6674
end
6775
function Base.show(io::IO, acc::LogLikelihoodAccumulator)
6876
return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))")
6977
end
70-
function Base.show(io::IO, acc::NumProduceAccumulator)
71-
return print(io, "NumProduceAccumulator($(repr(acc.num)))")
78+
function Base.show(io::IO, acc::VariableOrderAccumulator)
79+
return print(io, "VariableOrderAccumulator($(repr(acc.num_produce)), $(repr(acc.order)))")
7280
end
7381

7482
accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior
7583
accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood
76-
accumulator_name(::Type{<:NumProduceAccumulator}) = :NumProduce
84+
accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder
7785

7886
split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T))
7987
split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T))
80-
split(acc::NumProduceAccumulator) = acc
88+
split(acc::VariableOrderAccumulator) = acc
8189

8290
function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator)
8391
return LogPriorAccumulator(acc.logp + acc2.logp)
8492
end
8593
function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
8694
return LogLikelihoodAccumulator(acc.logp + acc2.logp)
8795
end
88-
function combine(acc::NumProduceAccumulator, acc2::NumProduceAccumulator)
89-
return NumProduceAccumulator(max(acc.num, acc2.num))
96+
function combine(acc::VariableOrderAccumulator, acc2::VariableOrderAccumulator)
97+
# Note that assumptions are not allowed within in parallelised blocks, and thus the
98+
# dictionaries should be identical.
99+
return VariableOrderAccumulator(max(acc.num_produce, acc2.num_produce), merge(acc.order, acc2.order))
90100
end
91101

92102
function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator)
@@ -95,11 +105,10 @@ end
95105
function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator)
96106
return LogLikelihoodAccumulator(acc1.logp + acc2.logp)
97107
end
98-
increment(acc::NumProduceAccumulator) = NumProduceAccumulator(acc.num + oneunit(acc.num))
108+
increment(acc::VariableOrderAccumulator) = VariableOrderAccumulator(acc.num_produce + oneunit(acc.num_produce), acc.order)
99109

100110
Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp))
101111
Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp))
102-
Base.zero(acc::NumProduceAccumulator) = NumProduceAccumulator(zero(acc.num))
103112

104113
function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right)
105114
return acc + LogPriorAccumulator(logpdf(right, val) + logjac)
@@ -114,8 +123,11 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
114123
return acc + LogLikelihoodAccumulator(Distributions.loglikelihood(right, left))
115124
end
116125

117-
accumulate_assume!!(acc::NumProduceAccumulator, val, logjac, vn, right) = acc
118-
accumulate_observe!!(acc::NumProduceAccumulator, right, left, vn) = increment(acc)
126+
function accumulate_assume!!(acc::VariableOrderAccumulator, val, logjac, vn, right)
127+
acc.order[vn] = acc.num_produce
128+
return acc
129+
end
130+
accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment(acc)
119131

120132
function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T}
121133
return LogPriorAccumulator(convert(T, acc.logp))
@@ -126,15 +138,19 @@ function Base.convert(
126138
return LogLikelihoodAccumulator(convert(T, acc.logp))
127139
end
128140
function Base.convert(
129-
::Type{NumProduceAccumulator{T}}, acc::NumProduceAccumulator
130-
) where {T}
131-
return NumProduceAccumulator(convert(T, acc.num))
141+
::Type{VariableOrderAccumulator{ElType, VnType}}, acc::VariableOrderAccumulator
142+
) where {ElType, VnType}
143+
order = OrderedDict{VnType, ElType}()
144+
for (k, v) in acc.order
145+
order[convert(VnType, k)] = convert(ElType, v)
146+
end
147+
return VariableOrderAccumulator(convert(ElType, acc.num_produce), order)
132148
end
133149

134150
# TODO(mhauru)
135-
# We ignore the convert_eltype calls for NumProduceAccumulator, by letting them fallback on
151+
# We ignore the convert_eltype calls for VariableOrderAccumulator, by letting them fallback on
136152
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
137-
# deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator. This is
153+
# deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator. This is
138154
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
139155
function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T}
140156
return LogPriorAccumulator(convert(T, acc.logp))
@@ -149,6 +165,6 @@ function default_accumulators(
149165
return AccumulatorTuple(
150166
LogPriorAccumulator{FloatT}(),
151167
LogLikelihoodAccumulator{FloatT}(),
152-
NumProduceAccumulator{IntT}(),
168+
VariableOrderAccumulator{IntT}(),
153169
)
154170
end

src/simple_varinfo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,15 @@ Evaluation in transformed space of course also works:
125125
126126
```jldoctest simplevarinfo-general
127127
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
128-
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))
128+
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, OrderedDict{VarName, Int64}())))
129129
130130
julia> # (✓) Positive probability mass on negative numbers!
131131
getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
132132
-1.3678794411714423
133133
134134
julia> # While if we forget to indicate that it's transformed:
135135
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
136-
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))
136+
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, OrderedDict{VarName, Int64}())))
137137
138138
julia> # (✓) No probability mass on negative numbers!
139139
getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))

src/threadsafe.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878

7979
syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)
8080

81-
setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index)
81+
setorder!!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = ThreadSafeVarInfo(setorder!!(vi.varinfo, vn, index), vi.accs_by_thread)
8282
setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
8383

8484
keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)

0 commit comments

Comments
 (0)