Skip to content

Commit 2b405d9

Browse files
committed
Fix variable order and name of map_accumulator!!
1 parent 68b974a commit 2b405d9

File tree

7 files changed

+58
-73
lines changed

7 files changed

+58
-73
lines changed

src/abstract_varinfo.jl

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ end
239239
Update all the accumulators of `vi` by calling `accumulate_assume!!` on them.
240240
"""
241241
function accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right)
242-
return map_accumulator!!(vi, accumulate_assume!!, val, logjac, vn, right)
242+
return map_accumulators!!(acc -> accumulate_assume!!(acc, val, logjac, vn, right), vi)
243243
end
244244

245245
"""
@@ -248,37 +248,35 @@ end
248248
Update all the accumulators of `vi` by calling `accumulate_observe!!` on them.
249249
"""
250250
function accumulate_observe!!(vi::AbstractVarInfo, right, left, vn)
251-
return map_accumulator!!(vi, accumulate_observe!!, right, left, vn)
251+
return map_accumulators!!(acc -> accumulate_observe!!(acc, right, left, vn), vi)
252252
end
253253

254254
"""
255-
map_accumulator!!(vi::AbstractVarInfo, func::Function, args...) where {accname}
255+
map_accumulators(vi::AbstractVarInfo, func::Function)
256256
257-
Update all accumulators of `vi` by calling `func(acc, args...)` on them and replacing
258-
them with the return values.
257+
Update all accumulators of `vi` by calling `func` on them and replacing them with the return
258+
values.
259259
"""
260-
function map_accumulator!!(vi::AbstractVarInfo, func::Function, args...)
261-
return setaccs!!(vi, map_accumulator!!(getaccs(vi), func, args...))
260+
function map_accumulators!!(func::Function, vi::AbstractVarInfo)
261+
return setaccs!!(vi, map(func, getaccs(vi)))
262262
end
263263

264264
"""
265-
map_accumulator!!(vi::AbstractVarInfo, ::Val{accname}, func::Function, args...) where {accname}
265+
map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) where {accname}
266266
267-
Update the accumulator `accname` of `vi` by calling `func(acc, args...)` on and replacing
268-
it with the return value.
267+
Update the accumulator `accname` of `vi` by calling `func` on it and replacing it with the
268+
return value.
269269
"""
270-
function map_accumulator!!(vi::AbstractVarInfo, accname::Val, func::Function, args...)
271-
return setaccs!!(vi, map_accumulator!!(getaccs(vi), accname, func, args...))
270+
function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Val)
271+
return setaccs!!(vi, map_accumulator(func, getaccs(vi), accname))
272272
end
273273

274-
function map_accumulator!!(vi::AbstractVarInfo, accname::Symbol, func::Function, args...)
274+
function map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol)
275275
return error(
276276
"""
277-
The method
278-
map_accumulator!!(vi::AbstractVarInfo, accname::Symbol, func::Function, args...)
277+
The method map_accumulator!!(func::Function, vi::AbstractVarInfo, accname::Symbol)
279278
does not exist. For type stability reasons use
280-
map_accumulator!!(vi::AbstractVarInfo, accname::Val, func::Function, args...)
281-
instead.
279+
map_accumulator!!(func::Function, vi::AbstractVarInfo, ::Val{accname}) instead.
282280
"""
283281
)
284282
end
@@ -291,7 +289,7 @@ Add `logp` to the value of the log of the prior probability in `vi`.
291289
See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getlogprior`](@ref), [`setlogprior!!`](@ref).
292290
"""
293291
function acclogprior!!(vi::AbstractVarInfo, logp)
294-
return map_accumulator!!(vi, Val(:LogPrior), +, LogPrior(logp))
292+
return map_accumulator!!(acc -> acc + LogPrior(logp), vi, Val(:LogPrior))
295293
end
296294

297295
"""
@@ -302,7 +300,7 @@ Add `logp` to the value of the log of the likelihood in `vi`.
302300
See also: [`accloglikelihood!!`](@ref), [`acclogp!!`](@ref), [`getloglikelihood`](@ref), [`setloglikelihood!!`](@ref).
303301
"""
304302
function accloglikelihood!!(vi::AbstractVarInfo, logp)
305-
return map_accumulator!!(vi, Val(:LogLikelihood), +, LogLikelihood(logp))
303+
return map_accumulator!!(acc -> acc + LogLikelihood(logp), vi, Val(:LogLikelihood))
306304
end
307305

308306
"""
@@ -326,10 +324,10 @@ Reset the values of the log probabilities (prior and likelihood) in `vi`
326324
"""
327325
function resetlogp!!(vi::AbstractVarInfo)
328326
if hasacc(vi, Val(:LogPrior))
329-
vi = map_accumulator!!(vi, Val(:LogPrior), zero)
327+
vi = map_accumulator!!(zero, vi, Val(:LogPrior))
330328
end
331329
if hasacc(vi, Val(:LogLikelihood))
332-
vi = map_accumulator!!(vi, Val(:LogLikelihood), zero)
330+
vi = map_accumulator!!(zero, vi, Val(:LogLikelihood))
333331
end
334332
return vi
335333
end

src/accumulators.jl

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,15 @@ function combine end
8989
# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in
9090
# src/varinfo.jl.
9191
"""
92-
convert_eltype(acc::AbstractAccumulator, ::Type{T})
92+
convert_eltype(::Type{T}, acc::AbstractAccumulator)
9393
9494
Convert `acc` to use element type `T`.
9595
9696
What "element type" means depends on the type of `acc`. By default this function does
9797
nothing. Accumulator types that need to hold differentiable values, such as dual numbers
9898
used by various AD backends, should implement a method for this function.
9999
"""
100-
convert_eltype(acc::AbstractAccumulator, ::Type) = acc
100+
convert_eltype(::Type, acc::AbstractAccumulator) = acc
101101

102102
# END ABSTRACT ACCUMULATOR, BEGIN ACCUMULATOR TUPLE
103103

@@ -167,36 +167,25 @@ function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname}
167167
return at[accname]
168168
end
169169

170-
"""
171-
map_accumulator!!(at::AccumulatorTuple, func::Function, args...)
172-
173-
Update the accumulators in `at` by calling `func(acc, args...)` on them and replacing them
174-
with the return values.
175-
176-
Returns a new `AccumulatorTuple`. The `!!` in the name is for consistency with the
177-
corresponding function for `AbstractVarInfo`.
178-
"""
179-
function map_accumulator!!(at::AccumulatorTuple, func::Function, args...)
180-
return AccumulatorTuple(map(acc -> func(acc, args...), at.nt))
170+
function Base.map(func::Function, at::AccumulatorTuple)
171+
return AccumulatorTuple(map(func, at.nt))
181172
end
182173

183174
"""
184-
map_accumulator!!(at::AccumulatorTuple, ::Val{accname}, func::Function, args...)
175+
map_accumulator(func::Function, at::AccumulatorTuple, ::Val{accname})
185176
186-
Update the accumulator with name `accname` in `at` by calling `func(acc, args...)` on it
187-
and replacing it with the return value.
177+
Update the accumulator with name `accname` in `at` by calling `func` on it.
188178
189-
Returns a new `AccumulatorTuple`. The `!!` in the name is for consistency with the
190-
corresponding function for `AbstractVarInfo`.
179+
Returns a new `AccumulatorTuple`.
191180
"""
192-
function map_accumulator!!(
193-
at::AccumulatorTuple, ::Val{accname}, func::Function, args...
181+
function map_accumulator(
182+
func::Function, at::AccumulatorTuple, ::Val{accname}
194183
) where {accname}
195184
# Would like to write this as
196185
# return Accessors.@set at.nt[accname] = func(at[accname], args...)
197186
# for readability, but that one isn't type stable due to
198187
# https://github.com/JuliaObjects/Accessors.jl/issues/198
199-
new_val = func(at[accname], args...)
188+
new_val = func(at[accname])
200189
new_nt = merge(at.nt, NamedTuple{(accname,)}((new_val,)))
201190
return AccumulatorTuple(new_nt)
202191
end
@@ -318,7 +307,7 @@ end
318307
# convert_eltype(::AbstractAccumulator, ::Type). This is because they are only used to
319308
# deal with dual number types of AD backends, which shouldn't concern NumProduce. This is
320309
# horribly hacky and should be fixed. See also comment in `unflatten` in `src/varinfo.jl`.
321-
convert_eltype(acc::LogPrior, ::Type{T}) where {T} = LogPrior(convert(T, acc.logp))
322-
function convert_eltype(acc::LogLikelihood, ::Type{T}) where {T}
310+
convert_eltype(::Type{T}, acc::LogPrior) where {T} = LogPrior(convert(T, acc.logp))
311+
function convert_eltype(::Type{T}, acc::LogLikelihood) where {T}
323312
return LogLikelihood(convert(T, acc.logp))
324313
end

src/simple_varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ function unflatten(svi::SimpleVarInfo, x::AbstractVector)
251251
# TODO(mhauru) See comment in unflatten in src/varinfo.jl for why this conversion is
252252
# required but undesireable.
253253
et = float_type_with_fallback(eltype(x))
254-
accs = map_accumulator!!(svi.accs, convert_eltype, et)
254+
accs = map(acc -> convert_eltype(et, acc), svi.accs)
255255
return SimpleVarInfo(vals, accs, svi.transformation)
256256
end
257257

src/threadsafe.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,17 @@ function getaccs(vi::ThreadSafeVarInfo)
4646
return AccumulatorTuple(map(anv -> getacc(vi, anv), accname_vals))
4747
end
4848

49-
# Calls to map_accumulator!! are thread-specific by default. For any use of them that should
50-
# _not_ be thread-specific a specific method has to be written.
51-
function map_accumulator!!(vi::ThreadSafeVarInfo, accname::Val, func::Function, args...)
49+
# Calls to map_accumulator(s)!! are thread-specific by default. For any use of them that
50+
# should _not_ be thread-specific a specific method has to be written.
51+
function map_accumulator!!(func::Function, vi::ThreadSafeVarInfo, accname::Val)
5252
tid = Threads.threadid()
53-
vi.accs_by_thread[tid] = map_accumulator!!(
54-
vi.accs_by_thread[tid], accname, func, args...
55-
)
53+
vi.accs_by_thread[tid] = map_accumulator(func, vi.accs_by_thread[tid], accname)
5654
return vi
5755
end
5856

59-
function map_accumulator!!(vi::ThreadSafeVarInfo, func::Function, args...)
57+
function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo)
6058
tid = Threads.threadid()
61-
vi.accs_by_thread[tid] = map_accumulator!!(vi.accs_by_thread[tid], func, args...)
59+
vi.accs_by_thread[tid] = map(func, vi.accs_by_thread[tid])
6260
return vi
6361
end
6462

@@ -186,9 +184,9 @@ end
186184
function resetlogp!!(vi::ThreadSafeVarInfo)
187185
vi = Accessors.@set vi.varinfo = resetlogp!!(vi.varinfo)
188186
for i in eachindex(vi.accs_by_thread)
189-
vi.accs_by_thread[i] = map_accumulator!!(vi.accs_by_thread[i], Val(:LogPrior), zero)
190-
vi.accs_by_thread[i] = map_accumulator!!(
191-
vi.accs_by_thread[i], Val(:LogLikelihood), zero
187+
vi.accs_by_thread[i] = map_accumulator(zero, vi.accs_by_thread[i], Val(:LogPrior))
188+
vi.accs_by_thread[i] = map_accumulator(
189+
zero, vi.accs_by_thread[i], Val(:LogLikelihood)
192190
)
193191
end
194192
return vi

src/transforming.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
4545
return _transform!!(NoTransformation(), DynamicTransformationContext{true}(), vi, model)
4646
end
4747

48-
function _transform(
48+
function _transform!!(
4949
t::AbstractTransformation,
5050
ctx::DynamicTransformationContext,
5151
vi::AbstractVarInfo,
@@ -54,8 +54,8 @@ function _transform(
5454
# To transform using DynamicTransformationContext, we evaluate the model, but we do not
5555
# need to use any accumulators other than LogPrior (which is affected by the Jacobian of
5656
# the transformation).
57-
accs = getaccs(vi.accs)
58-
has_logprior = hasacc(accs, Val(:LogPrior))
57+
accs = getaccs(vi)
58+
has_logprior = haskey(accs, Val(:LogPrior))
5959
if has_logprior
6060
old_logprior = getacc(accs, Val(:LogPrior))
6161
vi = setaccs!!(vi, (old_logprior,))

src/varinfo.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ function unflatten(vi::VarInfo, x::AbstractVector)
450450
# messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just
451451
# plain ugly and hacky.
452452
et = float_type_with_fallback(eltype(x))
453-
accs = map_accumulator!!(deepcopy(vi.accs), convert_eltype, et)
453+
accs = map(acc -> convert_eltype(et, acc), deepcopy(getaccs(vi)))
454454
return VarInfo(md, accs)
455455
end
456456

@@ -1032,15 +1032,15 @@ set_num_produce!!(vi::VarInfo, n::Int) = setacc!!(vi, NumProduce(n))
10321032
10331033
Add 1 to `num_produce` in `vi`.
10341034
"""
1035-
increment_num_produce!!(vi::VarInfo) = map_accumulator!!(vi, Val(:NumProduce), increment)
1035+
increment_num_produce!!(vi::VarInfo) = map_accumulator!!(increment, vi, Val(:NumProduce))
10361036

10371037
"""
10381038
reset_num_produce!!(vi::VarInfo)
10391039
10401040
Reset the value of `num_produce` the log of the joint probability of the observed data
10411041
and parameters sampled in `vi` to 0.
10421042
"""
1043-
reset_num_produce!!(vi::VarInfo) = map_accumulator!!(vi, Val(:NumProduce), zero)
1043+
reset_num_produce!!(vi::VarInfo) = map_accumulator!!(zero, vi, Val(:NumProduce))
10441044

10451045
# Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple).
10461046
isempty(vi::VarInfo) = _isempty(vi.metadata)

test/accumulators.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using DynamicPPL:
1414
convert_eltype,
1515
getacc,
1616
increment,
17-
map_accumulator!!,
17+
map_accumulator,
1818
setacc!!,
1919
split
2020

@@ -66,8 +66,8 @@ using DynamicPPL:
6666
LogLikelihood{Float32}(1.0f0)
6767
@test convert(NumProduce{UInt8}, NumProduce(1)) == NumProduce{UInt8}(1)
6868

69-
@test convert_eltype(LogPrior(1.0), Float32) == LogPrior{Float32}(1.0f0)
70-
@test convert_eltype(LogLikelihood(1.0), Float32) ==
69+
@test convert_eltype(Float32, LogPrior(1.0)) == LogPrior{Float32}(1.0f0)
70+
@test convert_eltype(Float32, LogLikelihood(1.0)) ==
7171
LogLikelihood{Float32}(1.0f0)
7272
end
7373

@@ -137,23 +137,23 @@ using DynamicPPL:
137137
@test getacc(at_all64, Val(:LogPrior)) == lp_f64
138138
end
139139

140-
@testset "map_accumulator!!" begin
140+
@testset "map_accumulator(s)!!" begin
141141
# map over all accumulators
142142
accs = AccumulatorTuple(lp_f32, ll_f32)
143-
@test map_accumulator!!(accs, zero) ==
144-
AccumulatorTuple(LogPrior(0.0f0), LogLikelihood(0.0f0))
143+
@test map(zero, accs) == AccumulatorTuple(LogPrior(0.0f0), LogLikelihood(0.0f0))
145144
# Test that the original wasn't modified.
146145
@test accs == AccumulatorTuple(lp_f32, ll_f32)
147146

148-
# A map with extra arguments that changes the types of the accumulators.
149-
@test map_accumulator!!(accs, convert_eltype, Float64) ==
147+
# A map with a closure that changes the types of the accumulators.
148+
@test map(acc -> convert_eltype(Float64, acc), accs) ==
150149
AccumulatorTuple(LogPrior(1.0), LogLikelihood(1.0))
151150

152151
# only apply to a particular accumulator
153-
@test map_accumulator!!(accs, Val(:LogLikelihood), zero) ==
152+
@test map_accumulator(zero, accs, Val(:LogLikelihood)) ==
154153
AccumulatorTuple(lp_f32, LogLikelihood(0.0f0))
155-
@test map_accumulator!!(accs, Val(:LogLikelihood), convert_eltype, Float64) ==
156-
AccumulatorTuple(lp_f32, LogLikelihood(1.0))
154+
@test map_accumulator(
155+
acc -> convert_eltype(Float64, acc), accs, Val(:LogLikelihood)
156+
) == AccumulatorTuple(lp_f32, LogLikelihood(1.0))
157157
end
158158
end
159159
end

0 commit comments

Comments
 (0)