Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ Other functions such as `tilde_assume` and `assume` (and their `observe` counter
Note that this was effectively already the case in DynamicPPL 0.37 (where they were just wrappers around each other).
The separation of these functions was primarily implemented to avoid performing extra work where unneeded (e.g. to not calculate the log-likelihood when `PriorContext` was being used). This functionality has since been replaced with accumulators (see the 0.37 changelog for more details).

### Removal of the `"del"` flag

Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed.

**Other changes**

### Reimplementation of functions using `InitContext`
Expand Down
6 changes: 2 additions & 4 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,8 @@ end
values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo)
values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T)

function unset_flag!(
vi::ThreadSafeVarInfo, vn::VarName, flag::String, ignoreable::Bool=false
)
return unset_flag!(vi.varinfo, vn, flag, ignoreable)
function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
return unset_flag!(vi.varinfo, vn, flag)
end
function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
return is_flagged(vi.varinfo, vn, flag)
Expand Down
42 changes: 9 additions & 33 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ Construct an empty type unstable instance of `Metadata`.
function Metadata()
vals = Vector{Real}()
flags = Dict{String,BitVector}()
flags["del"] = BitVector()
flags["trans"] = BitVector()

return Metadata(
Expand Down Expand Up @@ -887,12 +886,7 @@ function set_flag!(md::Metadata, vn::VarName, flag::String)
end

function set_flag!(vnv::VarNamedVector, ::VarName, flag::String)
if flag == "del"
# The "del" flag is effectively always set for a VarNamedVector, so this is a no-op.
else
throw(ErrorException("Flag $flag not valid for VarNamedVector"))
end
return vnv
throw(ErrorException("VarNamedVector does not support flags; Tried to set $(flag)."))
end

####
Expand Down Expand Up @@ -1710,7 +1704,7 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution)
[1:length(val)],
val,
[dist],
Dict{String,BitVector}("trans" => [false], "del" => [false]),
Dict{String,BitVector}("trans" => [false]),
)
vi = Accessors.@set vi.metadata[sym] = md
else
Expand Down Expand Up @@ -1744,7 +1738,6 @@ function Base.push!(meta::Metadata, vn, r, dist)
push!(meta.ranges, (l + 1):(l + n))
append!(meta.vals, val)
push!(meta.dists, dist)
push!(meta.flags["del"], false)
push!(meta.flags["trans"], false)
return meta
end
Expand All @@ -1770,42 +1763,25 @@ function is_flagged(metadata::Metadata, vn::VarName, flag::String)
return metadata.flags[flag][getidx(metadata, vn)]
end
function is_flagged(::VarNamedVector, ::VarName, flag::String)
if flag == "del"
return true
else
throw(ErrorException("Flag $flag not valid for VarNamedVector"))
end
throw(ErrorException("VarNamedVector does not support flags; Tried to read $(flag)."))
end

# TODO(mhauru) The "ignorable" argument is a temporary hack while developing VarNamedVector,
# but still having to support the interface based on Metadata too
"""
unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false
unset_flag!(vi::VarInfo, vn::VarName, flag::String

Set `vn`'s value for `flag` to `false` in `vi`.

Setting some flags for some `VarInfo` types is not possible, and by default attempting to do
so will error. If `ignorable` is set to `true` then this will silently be ignored instead.
"""
function unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false)
unset_flag!(getmetadata(vi, vn), vn, flag, ignorable)
function unset_flag!(vi::VarInfo, vn::VarName, flag::String)
unset_flag!(getmetadata(vi, vn), vn, flag)
return vi
end
function unset_flag!(metadata::Metadata, vn::VarName, flag::String, ignorable::Bool=false)
function unset_flag!(metadata::Metadata, vn::VarName, flag::String)
metadata.flags[flag][getidx(metadata, vn)] = false
return metadata
end

function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bool=false)
if ignorable
return vnv
end
if flag == "del"
throw(ErrorException("The \"del\" flag cannot be unset for VarNamedVector"))
else
throw(ErrorException("Flag $flag not valid for VarNamedVector"))
end
return vnv
function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String)
throw(ErrorException("VarNamedVector does not support flags; Tried to unset $(flag)."))
end

# TODO: Maybe rename or something?
Expand Down
17 changes: 6 additions & 11 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution)
r = rand(dist)
push!!(vi, vn, r, dist)
r
elseif DynamicPPL.is_flagged(vi, vn, "del")
DynamicPPL.unset_flag!(vi, vn, "del")
r = rand(dist)
vi[vn] = DynamicPPL.tovec(r)
r
else
vi[vn]
end
Expand Down Expand Up @@ -300,14 +295,14 @@ end

push!!(vi, vn_x, r, dist)

# del is set by default
@test !is_flagged(vi, vn_x, "del")
# trans is set by default
@test !is_flagged(vi, vn_x, "trans")

set_flag!(vi, vn_x, "del")
@test is_flagged(vi, vn_x, "del")
set_flag!(vi, vn_x, "trans")
@test is_flagged(vi, vn_x, "trans")

unset_flag!(vi, vn_x, "del")
@test !is_flagged(vi, vn_x, "del")
unset_flag!(vi, vn_x, "trans")
@test !is_flagged(vi, vn_x, "trans")
end
vi = VarInfo()
test_varinfo!(vi)
Expand Down
Loading