Skip to content

Commit d3ed55b

Browse files
committed
Use copy rather than deepcopy for accumulators
1 parent 4dd000c commit d3ed55b

File tree

7 files changed

+30
-11
lines changed

7 files changed

+30
-11
lines changed

src/accumulators.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ An accumulator type `T <: AbstractAccumulator` must implement the following meth
1313
- `accumulator_name(acc::T)` or `accumulator_name(::Type{T})`
1414
- `accumulate_observe!!(acc::T, right, left, vn)`
1515
- `accumulate_assume!!(acc::T, val, logjac, vn, right)`
16+
- `Base.copy(acc::T)`
1617
1718
To be able to work with multi-threading, it should also implement:
1819
- `split(acc::T)`
@@ -138,6 +139,7 @@ end
138139
Base.keys(at::AccumulatorTuple) = keys(at.nt)
139140
Base.:(==)(at1::AccumulatorTuple, at2::AccumulatorTuple) = at1.nt == at2.nt
140141
Base.hash(at::AccumulatorTuple, h::UInt) = Base.hash((AccumulatorTuple, at.nt), h)
142+
Base.copy(at::AccumulatorTuple) = AccumulatorTuple(map(copy, at.nt))
141143

142144
function Base.convert(::Type{AccumulatorTuple{N,T}}, accs::AccumulatorTuple{N}) where {N,T}
143145
return AccumulatorTuple(convert(T, accs.nt))

src/default_accumulators.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,16 @@ end
6666
Create a new `VariableOrderAccumulator` accumulator with the number of observations set to n
6767
"""
6868
VariableOrderAccumulator{T}(n=zero(T)) where {T<:Integer} =
69-
VariableOrderAccumulator(convert(T, n), OrderedDict{VarName,T}())
69+
VariableOrderAccumulator(convert(T, n), Dict{VarName,T}())
7070
VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n)
7171
VariableOrderAccumulator() = VariableOrderAccumulator{Int}()
7272

73+
Base.copy(acc::LogPriorAccumulator) = acc
74+
Base.copy(acc::LogLikelihoodAccumulator) = acc
75+
function Base.copy(acc::VariableOrderAccumulator)
76+
return VariableOrderAccumulator(acc.num_produce, copy(acc.order))
77+
end
78+
7379
function Base.show(io::IO, acc::LogPriorAccumulator)
7480
return print(io, "LogPriorAccumulator($(repr(acc.logp)))")
7581
end

src/extract_priors.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ end
44

55
PriorDistributionAccumulator() = PriorDistributionAccumulator(OrderedDict{VarName,Any}())
66

7+
function Base.copy(acc::PriorDistributionAccumulator)
8+
return PriorDistributionAccumulator(copy(acc.priors))
9+
end
10+
711
accumulator_name(::PriorDistributionAccumulator) = :PriorDistributionAccumulator
812

913
split(acc::PriorDistributionAccumulator) = PriorDistributionAccumulator(empty(acc.priors))

src/pointwise_logdensities.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ function PointwiseLogProbAccumulator{whichlogprob,KeyType}() where {whichlogprob
3131
return PointwiseLogProbAccumulator{whichlogprob,KeyType,typeof(logps)}(logps)
3232
end
3333

34+
function Base.copy(acc::PointwiseLogProbAccumulator{whichlogprob}) where {whichlogprob}
35+
return PointwiseLogProbAccumulator{whichlogprob}(copy(acc.logps))
36+
end
37+
3438
function Base.push!(acc::PointwiseLogProbAccumulator, vn, logp)
3539
logps = acc.logps
3640
# The last(fieldtypes(eltype(...))) gets the type of the values, rather than the keys.

src/simple_varinfo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ end
248248
# Constructor from `VarInfo`.
249249
function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D}
250250
values = values_as(vi, D)
251-
return SimpleVarInfo(values, deepcopy(getaccs(vi)))
251+
return SimpleVarInfo(values, copy(getaccs(vi)))
252252
end
253253
function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D}
254254
values = values_as(vi, D)
@@ -447,7 +447,7 @@ _subset(x::VarNamedVector, vns) = subset(x, vns)
447447
# `merge`
448448
function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
449449
values = merge(varinfo_left.values, varinfo_right.values)
450-
accs = deepcopy(getaccs(varinfo_right))
450+
accs = copy(getaccs(varinfo_right))
451451
transformation = merge_transformations(
452452
varinfo_left.transformation, varinfo_right.transformation
453453
)

src/values_as_in_model.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ function ValuesAsInModelAccumulator(include_colon_eq)
2020
return ValuesAsInModelAccumulator(OrderedDict(), include_colon_eq)
2121
end
2222

23+
function Base.copy(acc::ValuesAsInModelAccumulator)
24+
return ValuesAsInModelAccumulator(copy(acc.values), acc.include_colon_eq)
25+
end
26+
2327
accumulator_name(::Type{<:ValuesAsInModelAccumulator}) = :ValuesAsInModel
2428

2529
function split(acc::ValuesAsInModelAccumulator)

src/varinfo.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ function typed_varinfo(vi::UntypedVarInfo)
294294
)
295295
end
296296
nt = NamedTuple{syms_tuple}(Tuple(new_metas))
297-
return VarInfo(nt, deepcopy(vi.accs))
297+
return VarInfo(nt, copy(vi.accs))
298298
end
299299
function typed_varinfo(vi::NTVarInfo)
300300
# This function preserves the behaviour of typed_varinfo(vi) where vi is
@@ -355,7 +355,7 @@ single `VarNamedVector` as its metadata field.
355355
"""
356356
function untyped_vector_varinfo(vi::UntypedVarInfo)
357357
md = metadata_to_varnamedvector(vi.metadata)
358-
return VarInfo(md, deepcopy(vi.accs))
358+
return VarInfo(md, copy(vi.accs))
359359
end
360360
function untyped_vector_varinfo(
361361
rng::Random.AbstractRNG,
@@ -398,12 +398,12 @@ NamedTuple of `VarNamedVector`s as its metadata field.
398398
"""
399399
function typed_vector_varinfo(vi::NTVarInfo)
400400
md = map(metadata_to_varnamedvector, vi.metadata)
401-
return VarInfo(md, deepcopy(vi.accs))
401+
return VarInfo(md, copy(vi.accs))
402402
end
403403
function typed_vector_varinfo(vi::UntypedVectorVarInfo)
404404
new_metas = group_by_symbol(vi.metadata)
405405
nt = NamedTuple(new_metas)
406-
return VarInfo(nt, deepcopy(vi.accs))
406+
return VarInfo(nt, copy(vi.accs))
407407
end
408408
function typed_vector_varinfo(
409409
rng::Random.AbstractRNG,
@@ -455,8 +455,7 @@ function unflatten(vi::VarInfo, x::AbstractVector)
455455
# convert to into an intermediate variable makes this unstable (constant propagation)
456456
# fails. Take care when editing.
457457
accs = map(
458-
acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc),
459-
deepcopy(getaccs(vi)),
458+
acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi))
460459
)
461460
return VarInfo(md, accs)
462461
end
@@ -538,7 +537,7 @@ end
538537

539538
function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})
540539
metadata = subset(varinfo.metadata, vns)
541-
return VarInfo(metadata, deepcopy(varinfo.accs))
540+
return VarInfo(metadata, copy(varinfo.accs))
542541
end
543542

544543
function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
@@ -619,7 +618,7 @@ end
619618

620619
function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
621620
metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
622-
return VarInfo(metadata, deepcopy(varinfo_right.accs))
621+
return VarInfo(metadata, copy(varinfo_right.accs))
623622
end
624623

625624
function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector)

0 commit comments

Comments
 (0)