Skip to content

Commit 41f67fb

Browse files
committed
Try to condense implementation
1 parent d3a0cc0 commit 41f67fb

File tree

2 files changed

+30
-42
lines changed

2 files changed

+30
-42
lines changed

src/threadsafe.jl

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
1515

1616
# Instead of updating the log probability of the underlying variables we
1717
# just update the array of log probabilities.
18-
function acclogp!(vi::ThreadSafeVarInfo, logp::Real)
18+
function acclogp!(vi::ThreadSafeVarInfo, logp)
1919
vi.logps[Threads.threadid()] += logp
2020
return getlogp(vi)
2121
end
@@ -27,37 +27,21 @@ getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps)
2727
# TODO: Make remaining methods thread-safe.
2828

2929
function resetlogp!(vi::ThreadSafeVarInfo)
30-
resetlogp!(vi.varinfo)
31-
z = zero(getlogp(vi))
32-
fill!(vi.logps, z)
33-
z
30+
fill!(vi.logps, zero(getlogp(vi)))
31+
return resetlogp!(vi.varinfo)
3432
end
35-
function setlogp!(vi::ThreadSafeVarInfo, logp::Real)
36-
if length(vi.logp) == 0
37-
push!(vi.logp, logp)
38-
else
39-
vi.logp[1] = logp
40-
end
41-
vi.lastidx[] = 1
42-
return logp
33+
function setlogp!(vi::ThreadSafeVarInfo, logp)
34+
fill!(vi.logps, zero(logp))
35+
return setlogp!(vi.varinfo, logp)
4336
end
4437

4538
get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo)
4639
increment_num_produce!(vi::ThreadSafeVarInfo) = increment_num_produce!(vi.varinfo)
4740
reset_num_produce!(vi::ThreadSafeVarInfo) = reset_num_produce!(vi.varinfo)
4841
set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n)
4942

50-
getall(vi::ThreadSafeVarInfo) = getall(vi.varinfo)
51-
setall!(vi::ThreadSafeVarInfo, val) = setall!(vi.varinfo, val)
52-
5343
syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo)
5444

55-
getmetadata(vi::ThreadSafeVarInfo, vn::VarName) = getmetadata(vi.varinfo, vn)
56-
getidx(vi::ThreadSafeVarInfo, vn::VarName) = getidx(vi.varinfo, vn)
57-
getrange(vi::ThreadSafeVarInfo, vn::VarName) = getrange(vi.varinfo, vn)
58-
getdist(vi::ThreadSafeVarInfo, vn::VarName) = getdist(vi.varinfo, vn)
59-
getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)
60-
6145
function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName)
6246
setgid!(vi.varinfo, gid, vn)
6347
end
@@ -66,18 +50,25 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
6650
keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo)
6751
haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn)
6852

69-
_getranges(vi::ThreadSafeVarInfo, idcs::NamedTuple) = _getranges(vi.varinfo, idcs)
70-
_getidcs(vi::ThreadSafeVarInfo, spl::SampleFromPrior) = _getidcs(vi.varinfo, spl)
71-
_getidcs(vi::ThreadSafeVarInfo, s::Selector, space) = _getidcs(vi.varinfo, s, space)
72-
_getvns(vi::ThreadSafeVarInfo, spl::SampleFromPrior) = _getvns(vi.varinfo, spl)
73-
_getvns(vi::ThreadSafeVarInfo, s::Selector, space) = _getvns(vi.varinfo, s, space)
74-
7553
link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl)
7654
invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl)
7755
islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl)
7856

79-
getindex(vi::ThreadSafeVarInfo, spl::Sampler) = getindex(vi.varinfo, spl)
80-
setindex!(vi::ThreadSafeVarInfo, val, spl::Sampler) = setindex!(vi.varinfo, val, spl)
57+
getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl)
58+
getindex(vi::ThreadSafeVarInfo, spl::SampleFromPrior) = getindex(vi.varinfo, spl)
59+
getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, spl)
60+
getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn)
61+
getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns)
62+
63+
function setindex!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler)
64+
setindex!(vi.varinfo, val, spl)
65+
end
66+
function setindex!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior)
67+
setindex!(vi.varinfo, val, spl)
68+
end
69+
function setindex!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform)
70+
setindex!(vi.varinfo, val, spl)
71+
end
8172

8273
function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler)
8374
return set_retained_vns_del_by_spl!(vi.varinfo, spl)
@@ -99,9 +90,6 @@ function push!(
9990
)
10091
push!(vi.varinfo, vn, r, dist, gidset)
10192
end
102-
function push_assert(vi::ThreadSafeVarInfo, vn::VarName, dist, gidset)
103-
return push_assert(vi.varinfo, vn, dist, gidset)
104-
end
10593

10694
function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
10795
return unset_flag!(vi.varinfo, vn, flag)

src/varinfo.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -597,23 +597,23 @@ Return the log of the joint probability of the observed data and parameters samp
597597
getlogp(vi::AbstractVarInfo) = vi.logp[]
598598

599599
"""
600-
setlogp!(vi::VarInfo, logp::Real)
600+
setlogp!(vi::VarInfo, logp)
601601
602602
Set the log of the joint probability of the observed data and parameters sampled in
603603
`vi` to `logp`.
604604
"""
605-
setlogp!(vi::AbstractVarInfo, logp::Real) = vi.logp[] = logp
605+
setlogp!(vi::VarInfo, logp) = vi.logp[] = logp
606606

607607
"""
608-
acclogp!(vi::VarInfo, logp::Real)
608+
acclogp!(vi::VarInfo, logp)
609609
610610
Add `logp` to the value of the log of the joint probability of the observed data and
611611
parameters sampled in `vi`.
612612
"""
613-
acclogp!(vi::AbstractVarInfo, logp::Real) = vi.logp[] += logp
613+
acclogp!(vi::VarInfo, logp) = vi.logp[] += logp
614614

615615
"""
616-
resetlogp!(vi::VarInfo)
616+
resetlogp!(vi::AbstractVarInfo)
617617
618618
Reset the value of the log of the joint probability of the observed data and parameters
619619
sampled in `vi` to 0.
@@ -625,24 +625,24 @@ resetlogp!(vi::AbstractVarInfo) = setlogp!(vi, zero(getlogp(vi)))
625625
626626
Return the `num_produce` of `vi`.
627627
"""
628-
get_num_produce(vi::AbstractVarInfo) = vi.num_produce[]
628+
get_num_produce(vi::VarInfo) = vi.num_produce[]
629629

630630
"""
631631
set_num_produce!(vi::VarInfo, n::Int)
632632
633633
Set the `num_produce` field of `vi` to `n`.
634634
"""
635-
set_num_produce!(vi::AbstractVarInfo, n::Int) = vi.num_produce[] = n
635+
set_num_produce!(vi::VarInfo, n::Int) = vi.num_produce[] = n
636636

637637
"""
638638
increment_num_produce!(vi::VarInfo)
639639
640640
Add 1 to `num_produce` in `vi`.
641641
"""
642-
increment_num_produce!(vi::AbstractVarInfo) = vi.num_produce[] += 1
642+
increment_num_produce!(vi::VarInfo) = vi.num_produce[] += 1
643643

644644
"""
645-
reset_num_produce!(vi::VarInfo)
645+
reset_num_produce!(vi::AbstractVarInfo)
646646
647647
Reset the value of `num_produce` the log of the joint probability of the observed data
648648
and parameters sampled in `vi` to 0.

0 commit comments

Comments
 (0)