Skip to content

Commit 6824060

Browse files
authored
Merge pull request #103 from adolgert/feature/sampler-spec
Feature/sampler spec
2 parents 7a3d331 + 1dae39b commit 6824060

16 files changed

+258
-89
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
1111
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1212
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1313
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
14+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1415
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
1516
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1617
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
@@ -19,6 +20,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1920
Combinatorics = "^1.0"
2021
Distributions = "^0.25"
2122
Documenter = "^1.15"
23+
InteractiveUtils = "^1.10"
2224
Logging = "^1.0"
2325
Random = "^1.0"
2426
SpecialFunctions = "2.6.1"

docs/src/choosing_sampler.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ A sampler group is
5454
* A Symbol name for the sampler.
5555
* An inclusion function from `(ClockKey, Distribution)` to `Bool` that decides
5656
whether a given event belongs to this sampler.
57-
* An optional `sampler_spec` to say what kind of sampler this group should use.
57+
* An optional `method` to say what kind of sampler this group should use.
5858

5959
```julia
6060
builder = SamplerBuilder(KeyType, Float64)
61-
add_group!(builder, :sparky => (x,d) -> x[1] == :recover, sampler_spec=(:nextreaction,))
61+
add_group!(builder, :sparky => (x,d) -> x[1] == :recover, method=NextReaction())
6262
add_group!(builder, :forthright=>(x,d) -> x[1] == :infect)
6363
sampler = SamplingContext(builder, rng)
6464
```

docs/src/commonrandom.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ using CompetingClocks
3131
example_clock = (3, 7) # We will use clock IDs that are a tuple of 2 integers.
3232
model = MakeModel()
3333
(Key, Time) = (typeof(example_clock), Float64)
34-
builder = SamplerBuilder(Key, Time; sampler_spec=:firsttofire, common_random=true)
34+
builder = SamplerBuilder(Key, Time; common_random=true)
3535
rng = Xoshiro(9469922)
3636
sampler = SamplingContext(builder, rng)
3737
for trial_idx in 1:100

docs/src/gene_expression.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,8 @@ function run_epochs(epoch_cnt, use_importance, rng)
221221
model = GeneExpression(params)
222222
builder = SamplerBuilder(
223223
Tuple{Symbol,Int}, Float64;
224-
sampler_spec=:firsttofire,
225-
trajectory_likelihood=true,
224+
method=FirstToFireMethod(),
225+
path_likelihood=true,
226226
likelihood_cnt=2,
227227
)
228228
sampler = SamplingContext(builder, rng)

docs/src/integration-guide.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ at any point during a simulation. They are always calulated relative to the
118118
last event to `fire!()`, even if `next()` has been called.
119119
```julia
120120
builder = SamplerBuilder(KeyType, Float64;
121-
trajectory_likelihood=true)
121+
path_likelihood=true)
122122
sampler = SamplingContext(builder, rng)
123123
# After simulation, the log_prob is a Float64.
124124
log_prob = pathloglikelihood(sampler, end_time)

docs/src/simple_board.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,7 @@ end
4848
# * `enable!(sampler, event ID, distribution)`
4949
# * `disable!(sampler, event ID)`
5050
#
51-
# There are a lot of samplers in CompetingClocks to choose from. This example uses `CombinedNextReaction`
52-
# algorithm, which has good performance for a variety of distributions. Samplers in CompetingClocks
51+
# There are a lot of samplers in CompetingClocks to choose from. Samplers in CompetingClocks
5352
# require two type parameters, a key type for clocks and the type used to represent time.
5453
# In this case, the clock key type fully represents an event, giving the ID of the individual,
5554
# where they start, and which direction they may move.
@@ -62,7 +61,7 @@ end
6261

6362
function run(event_count)
6463
rng = Xoshiro(2947223)
65-
builder = SamplerBuilder(ClockKey, Float64; sampler_spec=(:nextreaction,))
64+
builder = SamplerBuilder(ClockKey, Float64)
6665
sampler = SamplingContext(builder, rng)
6766
physical = PhysicalState(zeros(Int, 10, 10))
6867
sim = SimulationFSM(

src/CompetingClocks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ include("sample/multiple_direct.jl")
2424
include("sample/combinednr.jl")
2525
include("variance/with_common_random.jl")
2626
include("sample/petri.jl")
27+
include("samplerspec.jl")
2728
include("sampler_builder.jl")
2829
include("context.jl")
2930

src/context.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ function SamplingContext(builder::SamplerBuilder, rng::R) where {R<:AbstractRNG}
3232
sampler = build_sampler(builder)
3333
if builder.likelihood_cnt > 1
3434
likelihood = PathLikelihoods{K,T}(builder.likelihood_cnt)
35-
elseif builder.trajectory_likelihood
35+
elseif builder.path_likelihood
3636
likelihood = TrajectoryWatcher{K,T}()
3737
else
3838
likelihood = nothing
@@ -56,6 +56,11 @@ function SamplingContext(builder::SamplerBuilder, rng::R) where {R<:AbstractRNG}
5656
end
5757

5858

59+
function SamplingContext(::Type{K}, ::Type{T}, rng::R; kwargs...) where {K,T,R<:AbstractRNG}
60+
return SamplingContext(SamplerBuilder(K, T; kwargs...), rng)
61+
end
62+
63+
5964
"""
6065
clone(sampling, rng)
6166
@@ -72,7 +77,7 @@ were making parallel RNGs with the `Random123` package, it might look like:
7277
master_seed = (0xd897a239, 0x77ff9238)
7378
rng = Philox4x((0, 0, 0, 0), master_seed)
7479
Key = Int64
75-
sampler = SamplingContext(SamplerBuilder(Key,Float64; trajectory_likelihood=true), rng)
80+
sampler = SamplingContext(SamplerBuilder(Key,Float64; path_likelihood=true), rng)
7681
observation_weight = zeros(Float64, particle_cnt)
7782
total_weight = zeros(Float64, particle_cnt)
7883
samplers = Vector{typeof(sampler)}(undef, particle_cnt)

src/sampler_builder.jl

Lines changed: 40 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export SamplerBuilder, available_samplers, add_group!, build_sampler
1+
export SamplerBuilder, add_group!, build_sampler
22

33
has_steploglikelihood(::Type) = false
44
has_steploglikelihood(::Type{<:CombinedNextReaction}) = true
@@ -13,35 +13,34 @@ has_pathloglikelihood(::Type{MultipleDirect}) = true
1313
mutable struct SamplerBuilderGroup
1414
name::Symbol
1515
selector::Union{Function,Nothing}
16-
sampler_spec::Tuple{Symbol}
16+
method::Union{SamplerSpec,Nothing}
1717
instance::SSA
18-
# Constructor sets the instance to undefined.
19-
SamplerBuilderGroup(name::Symbol, selector, sampler_spec) = new(name, selector, sampler_spec)
18+
# Constructor sets `instance` member to undefined.
19+
SamplerBuilderGroup(name::Symbol, selector, method) = new(name, selector, method)
2020
end
2121

2222

2323
struct SamplerBuilder{K,T}
2424
clock_type::Type{K}
2525
time_type::Type{T}
2626
step_likelihood::Bool
27-
trajectory_likelihood::Bool
27+
path_likelihood::Bool
2828
debug::Bool
2929
recording::Bool
3030
common_random::Bool
3131
group::Vector{SamplerBuilderGroup}
32-
samplers::Dict{Tuple{Symbol,Vararg{Symbol}},Function}
3332
start_time::T
3433
likelihood_cnt::Int
3534
end
3635

3736
"""
3837
SamplerBuilder(::Type{K}, ::Type{T};
3938
step_likelihood=false,
40-
trajectory_likelihood=false,
39+
path_likelihood=false,
4140
debug=false,
4241
recording=false,
4342
common_random=false,
44-
sampler_spec=:none,
43+
method=nothing,
4544
start_time::T,
4645
likelihood_cnt::Int
4746
)
@@ -51,50 +50,48 @@ an initial sampler.
5150
5251
* `K` and `T` are the clock type and time type.
5352
* `step_likelihood` - whether you will call `steploglikelihood` before each `fire!`
54-
* `trajectory_likelihood` - whether you will call `pathloglikelihood`
53+
* `path_likelihood` - whether you will call `pathloglikelihood`
5554
at the end of a simulation run.
5655
* `debug` - Print log messages at the debug level.
5756
* `recording` - Store every enable and disable for later examination.
5857
* `common_random` - Use common random numbers during sampling.
59-
* `sampler_spec` - If you want a single, particular sampler, put its Symbol name here.
58+
* `method` - If you want a single, particular sampler, put its `SamplerSpec` here.
59+
It will create a group called `:all` that has this sampling method.
6060
* `start_time` - Sometimes a simulation shouldn't start at zero.
6161
* `likelihood_cnt` - The number of likelihoods to compute, corresponds to number of
62-
distributions in `enable!` calls. This turns on `trajectory_likelihood`.
62+
distributions in `enable!` calls. This turns on `path_likelihood`.
6363
6464
# Example
6565
6666
```julia
6767
builder = SamplerBuilder(Tuple,Float64)
68-
add_group!(builder, :sparky => (x,d) -> x[1] == :recover, sampler_spec=(:nextreaction,))
68+
add_group!(builder, :sparky => (x,d) -> x[1] == :recover, method=NextReaction())
6969
add_group!(builder, :forthright=>(x,d) -> x[1] == :infect)
7070
context = SamplingContext(builder, rng)
7171
```
7272
"""
7373
function SamplerBuilder(::Type{K}, ::Type{T};
7474
step_likelihood=false,
75-
trajectory_likelihood=false,
75+
path_likelihood=false,
7676
debug=false,
7777
recording=false,
7878
common_random=false,
79-
sampler_spec::Union{Symbol,Tuple{Symbol}}=(:none,), # Ask for specific sampler.
79+
method::Union{SamplerSpec,Nothing}=nothing, # Ask for specific sampler.
8080
start_time::T=zero(T),
8181
likelihood_cnt=1,
8282
) where {K,T}
8383
group = SamplerBuilderGroup[]
84-
avail = make_builder_dict()
85-
trajectory_likelihood = trajectory_likelihood || likelihood_cnt > 1
84+
path_likelihood = path_likelihood || likelihood_cnt > 1
8685
builder = SamplerBuilder(
87-
K, T, step_likelihood, trajectory_likelihood, debug, recording,
88-
common_random, group, avail, start_time, likelihood_cnt
86+
K, T, step_likelihood, path_likelihood, debug, recording,
87+
common_random, group, start_time, likelihood_cnt
8988
)
90-
if sampler_spec != (:none,)
91-
add_group!(builder, :all => (x, d) -> true; sampler_spec=sampler_spec)
89+
if !isnothing(method)
90+
add_group!(builder, :all => (x, d) -> true; method=method)
9291
end
9392
return builder
9493
end
9594

96-
available_samplers(builder::SamplerBuilder) = keys(builder.samplers)
97-
9895

9996
"""
10097
The `selector` defines the group of clocks that go to this sampler using
@@ -103,37 +100,17 @@ an inclusion rule, so it's a function from a clock key and distribution to a Boo
103100
function add_group!(
104101
builder::SamplerBuilder,
105102
selector::Union{Pair,Nothing}=nothing; # Which clocks use this sampler.
106-
sampler_spec::Union{Symbol,Tuple{Symbol}}=(:any,), # Ask for specific sampler.
103+
method::Union{SamplerSpec,Nothing}=nothing, # Ask for specific sampler.
107104
)
108-
sampler_spec = sampler_spec isa Symbol ? (sampler_spec,) : sampler_spec
109-
sampler_spec = sampler_spec == (:any,) ? (:firsttofire,) : sampler_spec
110-
if sampler_spec keys(builder.samplers)
111-
error("Looking for a sampler in this list: $(keys(builder.samplers))")
112-
end
113105
if length(builder.group) >= 1 && (builder.group[1].selector === nothing || selector === nothing)
114106
error("Need a selector on all samplers if there is more than one sampler.")
115107
end
116108
name = selector isa Pair ? selector.first : :all
117109
selector_func = selector isa Pair ? selector.second : selector
118-
push!(builder.group, SamplerBuilderGroup(name, selector_func, sampler_spec))
110+
push!(builder.group, SamplerBuilderGroup(name, selector_func, method))
119111
return nothing
120112
end
121113

122-
123-
function make_builder_dict()
124-
return Dict([
125-
(:nextreaction,) => (K, T) -> CombinedNextReaction{K,T}(),
126-
(:direct,) => (K, T) -> DirectCallExplicit(K, T, KeyedRemovalPrefixSearch, BinaryTreePrefixSearch),
127-
(:direct, :remove, :tree) => (K, T) -> DirectCallExplicit(K, T, KeyedRemovalPrefixSearch, BinaryTreePrefixSearch),
128-
(:direct, :keep, :tree) => (K, T) -> DirectCallExplicit(K, T, KeyedKeepPrefixSearch, BinaryTreePrefixSearch),
129-
(:direct, :remove, :array) => (K, T) -> DirectCallExplicit(K, T, KeyedRemovalPrefixSearch, CumSumPrefixSearch),
130-
(:direct, :keep, :array) => (K, T) -> DirectCallExplicit(K, T, KeyedKeepPrefixSearch, CumSumPrefixSearch),
131-
(:firstreaction,) => (K, T) -> FirstReaction{K,T}(),
132-
(:firsttofire,) => (K, T) -> FirstToFire{K,T}(),
133-
(:petri,) => (K, T) -> Petri{K,T}(),
134-
])
135-
end
136-
137114
"""
138115
See sampler.jl for the MultiSampler to understand how we're making a chooser.
139116
We would like to implement this:
@@ -149,17 +126,32 @@ function CompetingClocks.choose_sampler(
149126
return chooser.matcher(clock, distribution)
150127
end
151128

129+
function auto_select_method(builder::SamplerBuilder)
130+
# Auto-select a sampler method based on builder requirements
131+
if builder.path_likelihood
132+
return DirectMethod()
133+
elseif builder.step_likelihood
134+
return NextReactionMethod()
135+
else
136+
return FirstToFireMethod()
137+
end
138+
end
139+
152140
function build_sampler(builder::SamplerBuilder)
153-
isempty(builder.group) && error("Need to add_group! on the builder.")
154141
K = builder.clock_type
155142
T = builder.time_type
156-
if length(builder.group) == 1
157-
sampler = builder.samplers[builder.group[1].sampler_spec](K, T)
143+
if length(builder.group) == 0
144+
sampler = FirstToFireMethod()(K, T)
145+
matcher = nothing
146+
elseif length(builder.group) == 1
147+
method = isnothing(builder.group[1].method) ? auto_select_method(builder) : builder.group[1].method
148+
sampler = method(K, T)
158149
matcher = nothing
159150
else
160151
competes = builder.group
161152
for compete in competes
162-
compete.instance = builder.samplers[compete.sampler_spec](K, T)
153+
method = isnothing(compete.method) ? auto_select_method(builder) : compete.method
154+
compete.instance = method(K, T)
163155
end
164156
# Any direct method gets added to the others for combination.
165157
inclusion = Dict(samp.name => samp.selector for samp in competes)

0 commit comments

Comments
 (0)