Skip to content

Commit 6686b36

Browse files
authored
Merge pull request #88 from adolgert/feature/sets-and-dicts
Feature/sets and dicts
2 parents 63345e8 + c215f5c commit 6686b36

20 files changed

+546
-97
lines changed

docs/src/algorithms.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,6 @@ choose
1515
setindex!
1616
rand
1717
set_multiple!
18+
SetOfSets
19+
PrefixEnabled
1820
```

docs/src/reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ freeze
2929
haskey
3030
misscount
3131
misses
32+
enabled
3233
MultiSampler
3334
ChatReaction
3435
Petri

src/CompetingClocks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Documenter
33

44
const ContinuousTime = AbstractFloat
55

6+
include("setofsets.jl")
67
include("prefixsearch/binarytreeprefixsearch.jl")
78
include("prefixsearch/cumsumprefixsearch.jl")
89
include("prefixsearch/keyedprefixsearch.jl")

src/prefixsearch/cumsumprefixsearch.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ and, each time the Direct method samples, this evaluates the cumulative
1313
sum of the array.
1414
"""
1515
struct CumSumPrefixSearch{T<:Real}
16-
array::Vector{T}
16+
array::Vector{T}
1717
cumulant::Vector{T}
1818
end
1919

@@ -49,7 +49,7 @@ function Base.setindex!(ps::CumSumPrefixSearch{T}, value::T, index) where {T}
4949
ps.array[index] = value
5050
value
5151
end
52-
52+
Base.getindex(ps::CumSumPrefixSearch, index) = ps.array[index]
5353

5454
function Base.sum!(ps::CumSumPrefixSearch{T})::T where {T}
5555
cumsum!(ps.cumulant, ps.array)

src/prefixsearch/keyedprefixsearch.jl

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ re-enables the same set of clocks, this is the faster choice.
1212
"""
1313
struct KeyedKeepPrefixSearch{T,P} <: KeyedPrefixSearch
1414
# Map from clock name to index in propensity array.
15-
index::Dict{T, Int}
15+
index::Dict{T,Int}
1616
# Map from index in propensity array to clock name.
1717
key::Vector{T}
1818
prefix::P
@@ -37,6 +37,7 @@ end
3737

3838

3939
Base.length(kp::KeyedKeepPrefixSearch) = length(kp.index)
40+
key_type(kp::KeyedKeepPrefixSearch{T,P}) where {T,P} = T
4041
time_type(kp::KeyedKeepPrefixSearch{T,P}) where {T,P} = time_type(P)
4142

4243

@@ -54,9 +55,51 @@ end
5455

5556

5657
isenabled(kp::KeyedKeepPrefixSearch, clock) = (
57-
haskey(kp.index, clock) && kp.prefix[clock] > zero(time_type(kp))
58+
haskey(kp.index, clock) && kp.prefix[kp.index[clock]] > zero(time_type(kp))
5859
)
5960

61+
"""
62+
Construct a set that checks which values are zeroed out because this
63+
prefix sum doesn't mark what has been deleted. That's faster for small
64+
sets of keys but makes getting the set that's enabled more difficult.
65+
A hazard rate that is set to zero at enabling will give a funny count
66+
because it's `enabled!()` by the user but set to never fire.
67+
"""
68+
struct PrefixEnabled{K,P,KK} <: AbstractSet{K}
69+
prefix::P
70+
keys::KK
71+
end
72+
73+
# Implements the interface to return a set of enabled clock keys.
74+
function enabled(prefix::KeyedKeepPrefixSearch{T,P}) where {T,P}
75+
kk = keys(prefix.index)
76+
PrefixEnabled{T,typeof(prefix),typeof(kk)}(prefix, kk)
77+
end
78+
79+
function Base.iterate(nre::PrefixEnabled)
80+
res = iterate(nre.keys)
81+
res === nothing && return res
82+
while !isenabled(nre.prefix, res[1])
83+
res = iterate(nre.keys, res[2])
84+
res === nothing && return res
85+
end
86+
return res
87+
end
88+
89+
90+
function Base.iterate(nre::PrefixEnabled, state)
91+
res = iterate(nre.keys, state)
92+
res === nothing && return res
93+
while !isenabled(nre.prefix, res[1])
94+
res = iterate(nre.keys, res[2])
95+
res === nothing && return res
96+
end
97+
return res
98+
end
99+
100+
Base.length(nre::PrefixEnabled) = count(x -> isenabled(nre.prefix, x), nre.keys)
101+
Base.in(x, nre::PrefixEnabled) = isenabled(nre.prefix, x)
102+
Base.eltype(::Type{PrefixEnabled{C}}) where {C} = C
60103

61104
Base.delete!(kp::KeyedKeepPrefixSearch, clock) = kp.prefix[kp.index[clock]] = zero(time_type(kp))
62105
function Base.sum!(kp::KeyedKeepPrefixSearch)
@@ -72,7 +115,7 @@ end
72115

73116
function Random.rand(
74117
rng::AbstractRNG, d::Random.SamplerTrivial{KeyedKeepPrefixSearch{T,P}}
75-
) where {T,P}
118+
) where {T,P}
76119
total = sum!(d[])
77120
LocalTime = time_type(P)
78121
choose(d[], rand(rng, Uniform{LocalTime}(zero(LocalTime), total)))
@@ -87,7 +130,7 @@ a large key space, this will use less memory.
87130
"""
88131
struct KeyedRemovalPrefixSearch{T,P} <: KeyedPrefixSearch
89132
# Map from clock name to index in propensity array.
90-
index::Dict{T, Int}
133+
index::Dict{T,Int}
91134
# Map from index in propensity array to clock name.
92135
key::Vector{T}
93136
free::Set{Int}
@@ -114,6 +157,7 @@ function Base.copy!(dst::KeyedRemovalPrefixSearch{T,P}, src::KeyedRemovalPrefixS
114157
end
115158

116159
Base.length(kp::KeyedRemovalPrefixSearch) = length(kp.index)
160+
key_type(kp::KeyedRemovalPrefixSearch{T,P}) where {T,P} = T
117161
time_type(kp::KeyedRemovalPrefixSearch{T,P}) where {T,P} = time_type(P)
118162

119163
function Base.setindex!(kp::KeyedRemovalPrefixSearch, val, clock)
@@ -134,10 +178,8 @@ function Base.setindex!(kp::KeyedRemovalPrefixSearch, val, clock)
134178
end
135179

136180

137-
isenabled(kp::KeyedRemovalPrefixSearch, clock) = (
138-
haskey(kp.index, clock) && kp.prefix[clock] > zero(time_type(kp))
139-
)
140-
181+
isenabled(kp::KeyedRemovalPrefixSearch, clock) = haskey(kp.index, clock)
182+
enabled(kp::KeyedRemovalPrefixSearch) = keys(kp.index)
141183

142184
function Base.getindex(kp::KeyedRemovalPrefixSearch, clock)
143185
if haskey(kp.index, clock)
@@ -170,7 +212,7 @@ end
170212

171213
function Random.rand(
172214
rng::AbstractRNG, d::Random.SamplerTrivial{KeyedRemovalPrefixSearch{T,P}}
173-
) where {T,P}
215+
) where {T,P}
174216
total = sum!(d[])
175217
LocalTime = time_type(P)
176218
choose(d[], rand(rng, Uniform{LocalTime}(zero(LocalTime), total)))

src/sample/combinednr.jl

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
using DataStructures: MutableBinaryMinHeap, extract_all!, update!
33

44
export sampling_space
5-
export CombinedNextReaction
5+
export CombinedNextReaction, enabled
66

77
"""
88
This function decides whether a particular distribution can be sampled faster
@@ -384,6 +384,55 @@ function Base.length(nr::CombinedNextReaction)
384384
return length(nr.transition_entry)
385385
end
386386

387+
388+
function isenabled(nr::CombinedNextReaction, clock)
389+
haskey(nr.transition_entry, clock) && nr.transition_entry[clock].heap_handle > 0
390+
end
391+
392+
393+
# A set of all enabled clock keys for a CombinedNextReaction method.
394+
# We make a custom Set implementation because the information is in the
395+
# CombinedNextReaction object, but it's spread across a Heap and a Dictionary.
396+
# This helper class should make it much more efficient to iterate the set.
397+
struct NextReactionEnabled{C,T,K} <: AbstractSet{C}
398+
nr::T
399+
keys::K
400+
end
401+
402+
_has_handle(nre::NextReactionEnabled, key) = nre.nr.transition_entry[key].heap_handle > 0
403+
404+
function Base.iterate(nre::NextReactionEnabled)
405+
res = iterate(nre.keys)
406+
res === nothing && return res
407+
while !_has_handle(nre, res[1])
408+
res = iterate(nre.keys, res[2])
409+
res === nothing && return res
410+
end
411+
return res
412+
end
413+
414+
415+
function Base.iterate(nre::NextReactionEnabled, state)
416+
res = iterate(nre.keys, state)
417+
res === nothing && return res
418+
while !_has_handle(nre, res[1])
419+
res = iterate(nre.keys, res[2])
420+
res === nothing && return res
421+
end
422+
return res
423+
end
424+
425+
Base.length(nre::NextReactionEnabled) = length(nre.nr.firing_queue)
426+
Base.in(x, nre::NextReactionEnabled) = isenabled(nre.nr, x)
427+
Base.eltype(::Type{NextReactionEnabled{C}}) where {C} = C
428+
429+
430+
function enabled(nr::CombinedNextReaction{K,T}) where {K,T}
431+
kks = keys(nr.transition_entry)
432+
NextReactionEnabled{K,typeof(nr),typeof(kks)}(nr, kks)
433+
end
434+
435+
387436
function Base.haskey(nr::CombinedNextReaction{K,T}, clock::K) where {K,T}
388437
return haskey(nr.transition_entry, clock) && nr.transition_entry[clock].heap_handle > 0
389438
end

src/sample/direct.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using Random: rand, AbstractRNG
22
using Distributions: Uniform, Exponential, rate
33

4-
export DirectCall, enable!, disable!, next
4+
export DirectCall, enable!, disable!, next, enabled
55

66

77
"""
@@ -126,13 +126,18 @@ function Base.getindex(dc::DirectCall{K,T,P}, clock::K) where {K,T,P}
126126
end
127127

128128
function Base.keys(dc::DirectCall)
129-
return collect(keys(dc.prefix_tree.index))
129+
return keys(dc.prefix_tree.index)
130130
end
131131

132132
function Base.length(dc::DirectCall)
133133
return length(dc.prefix_tree)
134134
end
135135

136+
# Implements the interface to return a set of enabled clock keys.
137+
enabled(dc::DirectCall{K,T,P}) where {K,T,P} = enabled(dc.prefix_tree)
138+
139+
isenabled(dc::DirectCall{K,T,P}, clock::K) where {K,T,P} = isenabled(dc.prefix_tree, clock)
140+
136141

137142
function steploglikelihood(dc::DirectCall, now, when, which)
138143
total = sum!(dc.prefix_tree)

src/sample/firstreaction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,4 @@ function next(fr::ChatReaction{K,T}, when::T, rng) where {K,T}
103103
end
104104

105105

106-
Base.haskey(fr::ChatReaction, clock) = isenabled(fr.enabled, clock)
106+
Base.haskey(fr::ChatReaction, clock) = haskey(fr.enabled, clock)

src/sample/firsttofire.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function next(propagator::FirstToFire{K,T}, when::T, rng::AbstractRNG) where {K,
4646
OrderedSample(nothing, typemax(T))
4747
end
4848
@debug("FirstToFire.next queue length ",
49-
length(propagator.firing_queue), " least ", least)
49+
length(propagator.firing_queue), " least ", least)
5050
(least.time, least.key)
5151
end
5252

@@ -101,3 +101,5 @@ end
101101

102102
Base.haskey(propagator::FirstToFire{K,T}, clock::K) where {K,T} = haskey(propagator.transition_entry, clock)
103103
Base.haskey(propagator::FirstToFire{K,T}, clock) where {K,T} = false
104+
105+
enabled(propagator::FirstToFire) = keys(propagator.transition_entry)

src/sample/interface.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Random: AbstractRNG
22
using Distributions: UnivariateDistribution
33
import Base: getindex, keys, length, keytype, haskey
44

5-
export SSA, enable!, disable!, next,
5+
export SSA, enable!, disable!, next,
66
getindex, keys, length, keytype
77

88
"""
@@ -39,7 +39,7 @@ function enable!(
3939
te::T, # enabling time
4040
when::T, # current simulation time
4141
rng::AbstractRNG
42-
) where {K,T}
42+
) where {K,T}
4343
@assert false
4444
end
4545

@@ -139,4 +139,11 @@ timetype(::SSA{K,T}) where {K,T} = T
139139
140140
Return a boolean.
141141
"""
142-
Base.haskey(sampler::SSA{K,T}, key) where {K,T}
142+
Base.haskey(sampler::SSA{K,T}, key) where {K,T}
143+
144+
"""
145+
enabled(sampler)
146+
147+
Returns a read-only set of currently-enabled clocks.
148+
"""
149+
enabled(sampler::SSA{K,T}) where {K,T}

0 commit comments

Comments
 (0)