Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,8 @@ end
values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo)
values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T)

function unset_flag!(
vi::ThreadSafeVarInfo, vn::VarName, flag::String, ignoreable::Bool=false
)
return unset_flag!(vi.varinfo, vn, flag, ignoreable)
function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
return unset_flag!(vi.varinfo, vn, flag)
end
function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
return is_flagged(vi.varinfo, vn, flag)
Expand Down
42 changes: 9 additions & 33 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ Construct an empty type unstable instance of `Metadata`.
function Metadata()
vals = Vector{Real}()
flags = Dict{String,BitVector}()
flags["del"] = BitVector()
flags["trans"] = BitVector()

return Metadata(
Expand Down Expand Up @@ -887,12 +886,7 @@ function set_flag!(md::Metadata, vn::VarName, flag::String)
end

function set_flag!(vnv::VarNamedVector, ::VarName, flag::String)
if flag == "del"
# The "del" flag is effectively always set for a VarNamedVector, so this is a no-op.
else
throw(ErrorException("Flag $flag not valid for VarNamedVector"))
end
return vnv
throw(ErrorException("VarNamedVector does not support flags; Tried to set $(flag)."))
end

####
Expand Down Expand Up @@ -1710,7 +1704,7 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution)
[1:length(val)],
val,
[dist],
Dict{String,BitVector}("trans" => [false], "del" => [false]),
Dict{String,BitVector}("trans" => [false]),
)
vi = Accessors.@set vi.metadata[sym] = md
else
Expand Down Expand Up @@ -1744,7 +1738,6 @@ function Base.push!(meta::Metadata, vn, r, dist)
push!(meta.ranges, (l + 1):(l + n))
append!(meta.vals, val)
push!(meta.dists, dist)
push!(meta.flags["del"], false)
push!(meta.flags["trans"], false)
return meta
end
Expand All @@ -1770,42 +1763,25 @@ function is_flagged(metadata::Metadata, vn::VarName, flag::String)
return metadata.flags[flag][getidx(metadata, vn)]
end
function is_flagged(::VarNamedVector, ::VarName, flag::String)
if flag == "del"
return true
else
throw(ErrorException("Flag $flag not valid for VarNamedVector"))
end
throw(ErrorException("VarNamedVector does not support flags; Tried to read $(flag)."))
end

# TODO(mhauru) The "ignorable" argument is a temporary hack while developing VarNamedVector,
# but still having to support the interface based on Metadata too
"""
unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false
unset_flag!(vi::VarInfo, vn::VarName, flag::String

Set `vn`'s value for `flag` to `false` in `vi`.

Setting some flags for some `VarInfo` types is not possible, and by default attempting to do
so will error. If `ignorable` is set to `true` then this will silently be ignored instead.
"""
function unset_flag!(vi::VarInfo, vn::VarName, flag::String, ignorable::Bool=false)
unset_flag!(getmetadata(vi, vn), vn, flag, ignorable)
function unset_flag!(vi::VarInfo, vn::VarName, flag::String)
unset_flag!(getmetadata(vi, vn), vn, flag)
return vi
end
function unset_flag!(metadata::Metadata, vn::VarName, flag::String, ignorable::Bool=false)
function unset_flag!(metadata::Metadata, vn::VarName, flag::String)
metadata.flags[flag][getidx(metadata, vn)] = false
return metadata
end

function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String, ignorable::Bool=false)
if ignorable
return vnv
end
if flag == "del"
throw(ErrorException("The \"del\" flag cannot be unset for VarNamedVector"))
else
throw(ErrorException("Flag $flag not valid for VarNamedVector"))
end
return vnv
function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String)
throw(ErrorException("VarNamedVector does not support flags; Tried to unset $(flag)."))
end

# TODO: Maybe rename or something?
Expand Down
17 changes: 6 additions & 11 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution)
r = rand(dist)
push!!(vi, vn, r, dist)
r
elseif DynamicPPL.is_flagged(vi, vn, "del")
DynamicPPL.unset_flag!(vi, vn, "del")
r = rand(dist)
vi[vn] = DynamicPPL.tovec(r)
r
else
vi[vn]
end
Expand Down Expand Up @@ -300,14 +295,14 @@ end

push!!(vi, vn_x, r, dist)

# del is set by default
@test !is_flagged(vi, vn_x, "del")
# trans is set by default
@test !is_flagged(vi, vn_x, "trans")

set_flag!(vi, vn_x, "del")
@test is_flagged(vi, vn_x, "del")
set_flag!(vi, vn_x, "trans")
@test is_flagged(vi, vn_x, "trans")

unset_flag!(vi, vn_x, "del")
@test !is_flagged(vi, vn_x, "del")
unset_flag!(vi, vn_x, "trans")
@test !is_flagged(vi, vn_x, "trans")
end
vi = VarInfo()
test_varinfo!(vi)
Expand Down
Loading