Skip to content

Commit fb09acc

Browse files
committed
Introduce default_accumulators()
1 parent c5e2a6b commit fb09acc

File tree

3 files changed

+12
-13
lines changed

3 files changed

+12
-13
lines changed

src/default_accumulators.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,11 @@ end
142142
function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T}
143143
return LogLikelihoodAccumulator(convert(T, acc.logp))
144144
end
145+
146+
function default_accumulators()
147+
return AccumulatorTuple(
148+
LogPriorAccumulator{LogProbType}(),
149+
LogLikelihoodAccumulator{LogProbType}(),
150+
NumProduceAccumulator{Int}(),
151+
)
152+
end

src/simple_varinfo.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,15 @@ Evaluation in transformed space of course also works:
125125
126126
```jldoctest simplevarinfo-general
127127
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
128-
Transformed SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihoodAccumulator(0.0), LogPrior = LogPriorAccumulator(0.0)))
128+
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))
129129
130130
julia> # (✓) Positive probability mass on negative numbers!
131131
getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
132132
-1.3678794411714423
133133
134134
julia> # While if we forget to indicate that it's transformed:
135135
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
136-
SimpleVarInfo((x = -1.0,), (LogLikelihood = LogLikelihoodAccumulator(0.0), LogPrior = LogPriorAccumulator(0.0)))
136+
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), NumProduce = NumProduceAccumulator(0)))
137137
138138
julia> # (✓) No probability mass on negative numbers!
139139
getlogjoint(last(DynamicPPL.evaluate!!(m, vi, DynamicPPL.DefaultContext())))
@@ -204,9 +204,7 @@ function SimpleVarInfo(values, accs)
204204
return SimpleVarInfo(values, accs, NoTransformation())
205205
end
206206
function SimpleVarInfo{T}(values) where {T<:Real}
207-
return SimpleVarInfo(
208-
values, AccumulatorTuple(LogLikelihoodAccumulator{T}(), LogPriorAccumulator{T}())
209-
)
207+
return SimpleVarInfo(values, default_accumulators())
210208
end
211209
function SimpleVarInfo(values)
212210
return SimpleVarInfo{LogProbType}(values)

src/varinfo.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,7 @@ struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo
102102
accs::Accs
103103
end
104104
function VarInfo(meta=Metadata())
105-
return VarInfo(
106-
meta,
107-
AccumulatorTuple(
108-
LogPriorAccumulator{LogProbType}(),
109-
LogLikelihoodAccumulator{LogProbType}(),
110-
NumProduceAccumulator{Int}(),
111-
),
112-
)
105+
return VarInfo(meta, default_accumulators())
113106
end
114107

115108
"""

0 commit comments

Comments
 (0)