Skip to content

Commit e2178c6

Browse files
authored
Fixes for SimpleVarInfo with Ref (#527)
* added missing getlogp impl for SimpleVarInfo with Ref * included SimpleVarInfo with Ref in the TestUtils.setup_varinfos * bump patch version * moved impls of acclogp!! and setlogp!! for SimpleVarInfo next to each other
1 parent 549d9b1 commit e2178c6

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.14"
3+
version = "0.23.15"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/simple_varinfo.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -259,17 +259,11 @@ end
259259
Base.isempty(vi::SimpleVarInfo) = isempty(vi.values)
260260

261261
getlogp(vi::SimpleVarInfo) = vi.logp
262+
getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[]
263+
262264
setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp
263265
acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp
264266

265-
"""
266-
keys(vi::SimpleVarInfo)
267-
268-
Return an iterator of keys present in `vi`.
269-
"""
270-
Base.keys(vi::SimpleVarInfo) = keys(vi.values)
271-
Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values))
272-
273267
function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
274268
vi.logp[] = logp
275269
return vi
@@ -280,6 +274,14 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
280274
return vi
281275
end
282276

277+
"""
278+
keys(vi::SimpleVarInfo)
279+
280+
Return an iterator of keys present in `vi`.
281+
"""
282+
Base.keys(vi::SimpleVarInfo) = keys(vi.values)
283+
Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values))
284+
283285
function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo)
284286
if !(svi.transformation isa NoTransformation)
285287
print(io, "Transformed ")

src/test_utils.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,22 @@ Return a tuple of instances for different implementations of `AbstractVarInfo` w
4343
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.
4444
"""
4545
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
46-
# <:VarInfo
46+
# VarInfo
4747
vi_untyped = VarInfo()
4848
model(vi_untyped)
4949
vi_typed = DynamicPPL.TypedVarInfo(vi_untyped)
50-
# <:SimpleVarInfo
50+
# SimpleVarInfo
5151
svi_typed = SimpleVarInfo(example_values)
5252
svi_untyped = SimpleVarInfo(OrderedDict())
5353

54+
# SimpleVarInfo{<:Any,<:Ref}
55+
svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed)))
56+
svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped)))
57+
5458
lp = getlogp(vi_typed)
55-
return map((vi_untyped, vi_typed, svi_typed, svi_untyped)) do vi
59+
return map((
60+
vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref
61+
)) do vi
5662
# Set them all to the same values.
5763
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
5864
end

0 commit comments

Comments
 (0)