@@ -41,52 +41,62 @@ LogLikelihoodAccumulator{T}() where {T<:Real} = LogLikelihoodAccumulator(zero(T)
41
41
LogLikelihoodAccumulator () = LogLikelihoodAccumulator {LogProbType} ()
42
42
43
43
"""
44
- NumProduceAccumulator {T} <: AbstractAccumulator
44
+ VariableOrderAccumulator {T} <: AbstractAccumulator
45
45
46
- An accumulator that tracks the number of observations during model execution.
46
+ An accumulator that tracks the order of variables in a `VarInfo`.
47
+
48
+ This doesn't track the full ordering, but rather how many observations have taken place
49
+ before the assume statement for each variable. This is needed for particle methods, where
50
+ the model is segmented into parts by each observation, and we need to know which part each
51
+ assume statement is in.
47
52
48
53
# Fields
49
54
$(TYPEDFIELDS)
50
55
"""
51
- struct NumProduceAccumulator{T <: Integer } <: AbstractAccumulator
56
+ struct VariableOrderAccumulator{Eltype <: Integer ,VNType <: VarName } <: AbstractAccumulator
52
57
" the number of observations"
53
- num:: T
58
+ num_produce:: Eltype
59
+ " mapping of variable names to their order in the model"
60
+ order:: OrderedDict{VNType, Eltype}
54
61
end
55
62
56
63
"""
57
- NumProduceAccumulator {T<:Integer}()
64
+ VariableOrderAccumulator {T<:Integer}(n=zero(T) )
58
65
59
- Create a new `NumProduceAccumulator ` accumulator with the number of observations initialized to zero.
66
+ Create a new `VariableOrderAccumulator ` accumulator with the number of observations set to n
60
67
"""
61
- NumProduceAccumulator {T} () where {T<: Integer } = NumProduceAccumulator (zero (T))
62
- NumProduceAccumulator () = NumProduceAccumulator {Int} ()
68
+ VariableOrderAccumulator {T} (n= zero (T)) where {T<: Integer } = VariableOrderAccumulator (convert (T, n), OrderedDict {VarName, T} ())
69
+ VariableOrderAccumulator (n) = VariableOrderAccumulator {typeof(n)} (n)
70
+ VariableOrderAccumulator () = VariableOrderAccumulator {Int} ()
63
71
64
72
function Base. show (io:: IO , acc:: LogPriorAccumulator )
65
73
return print (io, " LogPriorAccumulator($(repr (acc. logp)) )" )
66
74
end
67
75
function Base. show (io:: IO , acc:: LogLikelihoodAccumulator )
68
76
return print (io, " LogLikelihoodAccumulator($(repr (acc. logp)) )" )
69
77
end
70
- function Base. show (io:: IO , acc:: NumProduceAccumulator )
71
- return print (io, " NumProduceAccumulator ($(repr (acc. num )) )" )
78
+ function Base. show (io:: IO , acc:: VariableOrderAccumulator )
79
+ return print (io, " VariableOrderAccumulator ($(repr (acc. num_produce)) , $( repr (acc . order )) )" )
72
80
end
73
81
74
82
accumulator_name (:: Type{<:LogPriorAccumulator} ) = :LogPrior
75
83
accumulator_name (:: Type{<:LogLikelihoodAccumulator} ) = :LogLikelihood
76
- accumulator_name (:: Type{<:NumProduceAccumulator } ) = :NumProduce
84
+ accumulator_name (:: Type{<:VariableOrderAccumulator } ) = :VariableOrder
77
85
78
86
split (:: LogPriorAccumulator{T} ) where {T} = LogPriorAccumulator (zero (T))
79
87
split (:: LogLikelihoodAccumulator{T} ) where {T} = LogLikelihoodAccumulator (zero (T))
80
- split (acc:: NumProduceAccumulator ) = acc
88
+ split (acc:: VariableOrderAccumulator ) = acc
81
89
82
90
function combine (acc:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
83
91
return LogPriorAccumulator (acc. logp + acc2. logp)
84
92
end
85
93
function combine (acc:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
86
94
return LogLikelihoodAccumulator (acc. logp + acc2. logp)
87
95
end
88
- function combine (acc:: NumProduceAccumulator , acc2:: NumProduceAccumulator )
89
- return NumProduceAccumulator (max (acc. num, acc2. num))
96
+ function combine (acc:: VariableOrderAccumulator , acc2:: VariableOrderAccumulator )
97
+ # Note that assumptions are not allowed within in parallelised blocks, and thus the
98
+ # dictionaries should be identical.
99
+ return VariableOrderAccumulator (max (acc. num_produce, acc2. num_produce), merge (acc. order, acc2. order))
90
100
end
91
101
92
102
function Base.:+ (acc1:: LogPriorAccumulator , acc2:: LogPriorAccumulator )
95
105
function Base.:+ (acc1:: LogLikelihoodAccumulator , acc2:: LogLikelihoodAccumulator )
96
106
return LogLikelihoodAccumulator (acc1. logp + acc2. logp)
97
107
end
98
- increment (acc:: NumProduceAccumulator ) = NumProduceAccumulator (acc. num + oneunit (acc. num) )
108
+ increment (acc:: VariableOrderAccumulator ) = VariableOrderAccumulator (acc. num_produce + oneunit (acc. num_produce), acc . order )
99
109
100
110
Base. zero (acc:: LogPriorAccumulator ) = LogPriorAccumulator (zero (acc. logp))
101
111
Base. zero (acc:: LogLikelihoodAccumulator ) = LogLikelihoodAccumulator (zero (acc. logp))
102
- Base. zero (acc:: NumProduceAccumulator ) = NumProduceAccumulator (zero (acc. num))
103
112
104
113
function accumulate_assume!! (acc:: LogPriorAccumulator , val, logjac, vn, right)
105
114
return acc + LogPriorAccumulator (logpdf (right, val) + logjac)
@@ -114,8 +123,11 @@ function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn)
114
123
return acc + LogLikelihoodAccumulator (Distributions. loglikelihood (right, left))
115
124
end
116
125
117
- accumulate_assume!! (acc:: NumProduceAccumulator , val, logjac, vn, right) = acc
118
- accumulate_observe!! (acc:: NumProduceAccumulator , right, left, vn) = increment (acc)
126
+ function accumulate_assume!! (acc:: VariableOrderAccumulator , val, logjac, vn, right)
127
+ acc. order[vn] = acc. num_produce
128
+ return acc
129
+ end
130
+ accumulate_observe!! (acc:: VariableOrderAccumulator , right, left, vn) = increment (acc)
119
131
120
132
function Base. convert (:: Type{LogPriorAccumulator{T}} , acc:: LogPriorAccumulator ) where {T}
121
133
return LogPriorAccumulator (convert (T, acc. logp))
@@ -126,15 +138,19 @@ function Base.convert(
126
138
return LogLikelihoodAccumulator (convert (T, acc. logp))
127
139
end
128
140
function Base. convert (
129
- :: Type{NumProduceAccumulator{T}} , acc:: NumProduceAccumulator
130
- ) where {T}
131
- return NumProduceAccumulator (convert (T, acc. num))
141
+ :: Type{VariableOrderAccumulator{ElType, VnType}} , acc:: VariableOrderAccumulator
142
+ ) where {ElType, VnType}
143
+ order = OrderedDict {VnType, ElType} ()
144
+ for (k, v) in acc. order
145
+ order[convert (VnType, k)] = convert (ElType, v)
146
+ end
147
+ return VariableOrderAccumulator (convert (ElType, acc. num_produce), order)
132
148
end
133
149
134
150
# TODO (mhauru)
135
- # We ignore the convert_eltype calls for NumProduceAccumulator , by letting them fallback on
151
+ # We ignore the convert_eltype calls for VariableOrderAccumulator , by letting them fallback on
136
152
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
137
- # deal with dual number types of AD backends, which shouldn't concern NumProduceAccumulator . This is
153
+ # deal with dual number types of AD backends, which shouldn't concern VariableOrderAccumulator . This is
138
154
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
139
155
function convert_eltype (:: Type{T} , acc:: LogPriorAccumulator ) where {T}
140
156
return LogPriorAccumulator (convert (T, acc. logp))
@@ -149,6 +165,6 @@ function default_accumulators(
149
165
return AccumulatorTuple (
150
166
LogPriorAccumulator {FloatT} (),
151
167
LogLikelihoodAccumulator {FloatT} (),
152
- NumProduceAccumulator {IntT} (),
168
+ VariableOrderAccumulator {IntT} (),
153
169
)
154
170
end
0 commit comments