Skip to content

Commit 3f195e5

Browse files
committed
Various small fixes
1 parent e1b70e0 commit 3f195e5

File tree

4 files changed

+11
-16
lines changed

4 files changed

+11
-16
lines changed

src/abstract_varinfo.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,13 @@ end
105105

106106
"""
107107
setaccs!!(vi::AbstractVarInfo, accs::AccumulatorTuple)
108-
setaccs!!(vi::AbstractVarInfo, accs::AbstractAccumulator...)
108+
setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator} where {N})
109109
110110
Update the `AccumulatorTuple` of `vi` to `accs`, mutating if it makes sense.
111111
112112
`setaccs!!(vi:AbstractVarInfo, accs::AccumulatorTuple) should be implemented by each subtype
113113
of `AbstractVarInfo`.
114114
"""
115-
function setaccs!! end
116-
117115
function setaccs!!(vi::AbstractVarInfo, accs::NTuple{N,AbstractAccumulator}) where {N}
118116
return setaccs!!(vi, AccumulatorTuple(accs))
119117
end

src/varinfo.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ function typed_varinfo(vi::UntypedVarInfo)
292292
)
293293
end
294294
nt = NamedTuple{syms_tuple}(Tuple(new_metas))
295-
return VarInfo(nt, vi.accs)
295+
return VarInfo(nt, deepcopy(vi.accs))
296296
end
297297
function typed_varinfo(vi::NTVarInfo)
298298
# This function preserves the behaviour of typed_varinfo(vi) where vi is
@@ -353,7 +353,7 @@ single `VarNamedVector` as its metadata field.
353353
"""
354354
function untyped_vector_varinfo(vi::UntypedVarInfo)
355355
md = metadata_to_varnamedvector(vi.metadata)
356-
return VarInfo(md, vi.accs)
356+
return VarInfo(md, deepcopy(vi.accs))
357357
end
358358
function untyped_vector_varinfo(
359359
rng::Random.AbstractRNG,
@@ -396,12 +396,12 @@ NamedTuple of `VarNamedVector`s as its metadata field.
396396
"""
397397
function typed_vector_varinfo(vi::NTVarInfo)
398398
md = map(metadata_to_varnamedvector, vi.metadata)
399-
return VarInfo(md, vi.accs)
399+
return VarInfo(md, deepcopy(vi.accs))
400400
end
401401
function typed_vector_varinfo(vi::UntypedVectorVarInfo)
402402
new_metas = group_by_symbol(vi.metadata)
403403
nt = NamedTuple(new_metas)
404-
return VarInfo(nt, vi.accs)
404+
return VarInfo(nt, deepcopy(vi.accs))
405405
end
406406
function typed_vector_varinfo(
407407
rng::Random.AbstractRNG,
@@ -450,7 +450,7 @@ function unflatten(vi::VarInfo, x::AbstractVector)
450450
# messes with cases like using Float32 of logprobs and Float64 for x. Also, this is just
451451
# plain ugly and hacky.
452452
et = float_type_with_fallback(eltype(x))
453-
accs = map_accumulator!!(vi.accs, convert_eltype, et)
453+
accs = map_accumulator!!(deepcopy(vi.accs), convert_eltype, et)
454454
return VarInfo(md, accs)
455455
end
456456

@@ -533,7 +533,7 @@ end
533533

534534
function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})
535535
metadata = subset(varinfo.metadata, vns)
536-
return VarInfo(metadata, varinfo.accs)
536+
return VarInfo(metadata, deepcopy(varinfo.accs))
537537
end
538538

539539
function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
@@ -622,7 +622,7 @@ end
622622

623623
function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
624624
metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
625-
return VarInfo(metadata, varinfo_right.accs)
625+
return VarInfo(metadata, deepcopy(varinfo_right.accs))
626626
end
627627

628628
function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector)

test/contexts.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,8 @@ Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown()
4646
Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
4747

4848
@testset "contexts.jl" begin
49-
child_contexts = Dict(:default => DefaultContext())
50-
51-
parent_contexts = Dict(
49+
contexts = Dict(
50+
:default => DefaultContext(),
5251
:testparent => DynamicPPL.TestUtils.TestParentContext(DefaultContext()),
5352
:sampling => SamplingContext(),
5453
:prefix => PrefixContext(@varname(x)),
@@ -63,8 +62,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
6362
:condition4 => ConditionContext((x=[1.0, missing],)),
6463
)
6564

66-
contexts = merge(child_contexts, parent_contexts)
67-
6865
@testset "$(name)" for (name, context) in contexts
6966
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
7067
DynamicPPL.TestUtils.test_context(context, model)

test/pointwise_logdensities.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@testset "logdensities_likelihoods.jl" begin
2-
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS[1:1]
2+
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
33
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
44

55
# Instantiate a `VarInfo` with the example values.

0 commit comments

Comments
 (0)