@@ -15,7 +15,7 @@ ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
15
15
16
16
# Instead of updating the log probability of the underlying variables we
17
17
# just update the array of log probabilities.
18
- function acclogp! (vi:: ThreadSafeVarInfo , logp:: Real )
18
+ function acclogp! (vi:: ThreadSafeVarInfo , logp)
19
19
vi. logps[Threads. threadid ()] += logp
20
20
return getlogp (vi)
21
21
end
@@ -27,37 +27,21 @@ getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps)
27
27
# TODO : Make remaining methods thread-safe.
28
28
29
29
function resetlogp! (vi:: ThreadSafeVarInfo )
30
- resetlogp! (vi. varinfo)
31
- z = zero (getlogp (vi))
32
- fill! (vi. logps, z)
33
- z
30
+ fill! (vi. logps, zero (getlogp (vi)))
31
+ return resetlogp! (vi. varinfo)
34
32
end
35
- function setlogp! (vi:: ThreadSafeVarInfo , logp:: Real )
36
- if length (vi. logp) == 0
37
- push! (vi. logp, logp)
38
- else
39
- vi. logp[1 ] = logp
40
- end
41
- vi. lastidx[] = 1
42
- return logp
33
+ function setlogp! (vi:: ThreadSafeVarInfo , logp)
34
+ fill! (vi. logps, zero (logp))
35
+ return setlogp! (vi. varinfo, logp)
43
36
end
44
37
45
38
get_num_produce (vi:: ThreadSafeVarInfo ) = get_num_produce (vi. varinfo)
46
39
increment_num_produce! (vi:: ThreadSafeVarInfo ) = increment_num_produce! (vi. varinfo)
47
40
reset_num_produce! (vi:: ThreadSafeVarInfo ) = reset_num_produce! (vi. varinfo)
48
41
set_num_produce! (vi:: ThreadSafeVarInfo , n:: Int ) = set_num_produce! (vi. varinfo, n)
49
42
50
- getall (vi:: ThreadSafeVarInfo ) = getall (vi. varinfo)
51
- setall! (vi:: ThreadSafeVarInfo , val) = setall! (vi. varinfo, val)
52
-
53
43
syms (vi:: ThreadSafeVarInfo ) = syms (vi. varinfo)
54
44
55
- getmetadata (vi:: ThreadSafeVarInfo , vn:: VarName ) = getmetadata (vi. varinfo, vn)
56
- getidx (vi:: ThreadSafeVarInfo , vn:: VarName ) = getidx (vi. varinfo, vn)
57
- getrange (vi:: ThreadSafeVarInfo , vn:: VarName ) = getrange (vi. varinfo, vn)
58
- getdist (vi:: ThreadSafeVarInfo , vn:: VarName ) = getdist (vi. varinfo, vn)
59
- getval (vi:: ThreadSafeVarInfo , vn:: VarName ) = getval (vi. varinfo, vn)
60
-
61
45
function setgid! (vi:: ThreadSafeVarInfo , gid:: Selector , vn:: VarName )
62
46
setgid! (vi. varinfo, gid, vn)
63
47
end
@@ -66,18 +50,25 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn)
66
50
keys (vi:: ThreadSafeVarInfo ) = keys (vi. varinfo)
67
51
haskey (vi:: ThreadSafeVarInfo , vn:: VarName ) = haskey (vi. varinfo, vn)
68
52
69
- _getranges (vi:: ThreadSafeVarInfo , idcs:: NamedTuple ) = _getranges (vi. varinfo, idcs)
70
- _getidcs (vi:: ThreadSafeVarInfo , spl:: SampleFromPrior ) = _getidcs (vi. varinfo, spl)
71
- _getidcs (vi:: ThreadSafeVarInfo , s:: Selector , space) = _getidcs (vi. varinfo, s, space)
72
- _getvns (vi:: ThreadSafeVarInfo , spl:: SampleFromPrior ) = _getvns (vi. varinfo, spl)
73
- _getvns (vi:: ThreadSafeVarInfo , s:: Selector , space) = _getvns (vi. varinfo, s, space)
74
-
75
53
link! (vi:: ThreadSafeVarInfo , spl:: AbstractSampler ) = link! (vi. varinfo, spl)
76
54
invlink! (vi:: ThreadSafeVarInfo , spl:: AbstractSampler ) = invlink! (vi. varinfo, spl)
77
55
islinked (vi:: ThreadSafeVarInfo , spl:: AbstractSampler ) = islinked (vi. varinfo, spl)
78
56
79
- getindex (vi:: ThreadSafeVarInfo , spl:: Sampler ) = getindex (vi. varinfo, spl)
80
- setindex! (vi:: ThreadSafeVarInfo , val, spl:: Sampler ) = setindex! (vi. varinfo, val, spl)
57
+ getindex (vi:: ThreadSafeVarInfo , spl:: AbstractSampler ) = getindex (vi. varinfo, spl)
58
+ getindex (vi:: ThreadSafeVarInfo , spl:: SampleFromPrior ) = getindex (vi. varinfo, spl)
59
+ getindex (vi:: ThreadSafeVarInfo , spl:: SampleFromUniform ) = getindex (vi. varinfo, spl)
60
+ getindex (vi:: ThreadSafeVarInfo , vn:: VarName ) = getindex (vi. varinfo, vn)
61
+ getindex (vi:: ThreadSafeVarInfo , vns:: Vector{<:VarName} ) = getindex (vi. varinfo, vns)
62
+
63
+ function setindex! (vi:: ThreadSafeVarInfo , val, spl:: AbstractSampler )
64
+ setindex! (vi. varinfo, val, spl)
65
+ end
66
+ function setindex! (vi:: ThreadSafeVarInfo , val, spl:: SampleFromPrior )
67
+ setindex! (vi. varinfo, val, spl)
68
+ end
69
+ function setindex! (vi:: ThreadSafeVarInfo , val, spl:: SampleFromUniform )
70
+ setindex! (vi. varinfo, val, spl)
71
+ end
81
72
82
73
function set_retained_vns_del_by_spl! (vi:: ThreadSafeVarInfo , spl:: Sampler )
83
74
return set_retained_vns_del_by_spl! (vi. varinfo, spl)
@@ -99,9 +90,6 @@ function push!(
99
90
)
100
91
push! (vi. varinfo, vn, r, dist, gidset)
101
92
end
102
- function push_assert (vi:: ThreadSafeVarInfo , vn:: VarName , dist, gidset)
103
- return push_assert (vi. varinfo, vn, dist, gidset)
104
- end
105
93
106
94
function unset_flag! (vi:: ThreadSafeVarInfo , vn:: VarName , flag:: String )
107
95
return unset_flag! (vi. varinfo, vn, flag)
0 commit comments