Skip to content

Commit c1a66ad

Browse files
committed
modified the direct method to optionally accumulate log likelihood
1 parent 3e40ce4 commit c1a66ad

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/sample/direct.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,16 @@ sampler_noremove = DirectCall{K,T,typeof(keyed_prefix_tree)}(keyed_prefix_tree)
3535
"""
3636
struct DirectCall{K,T,P} <: SSA{K,T}
3737
prefix_tree::P
38+
now::Float64
39+
log_likelihood::Float64
40+
calculate_likelihood::Bool
3841
end
3942

4043

41-
function DirectCall{K,T}() where {K,T<:ContinuousTime}
44+
function DirectCall{K,T}(; trajectory=false) where {K,T<:ContinuousTime}
4245
prefix_tree = BinaryTreePrefixSearch{T}()
4346
keyed_prefix_tree = KeyedRemovalPrefixSearch{K,typeof(prefix_tree)}(prefix_tree)
44-
DirectCall{K,T,typeof(keyed_prefix_tree)}(keyed_prefix_tree)
47+
DirectCall{K,T,typeof(keyed_prefix_tree)}(keyed_prefix_tree, 0.0, 0.0, trajectory)
4548
end
4649

4750

@@ -64,7 +67,7 @@ If a particular clock had one rate before an event and it has another rate
6467
after the event, call `enable!` to update the rate.
6568
"""
6669
function enable!(dc::DirectCall{K,T,P}, clock::K, distribution::Exponential,
67-
te::T, when::T, rng::AbstractRNG) where {K,T,P}
70+
te::T, when::T, rng::AbstractRNG) where {K,T,P}
6871
dc.prefix_tree[clock] = rate(distribution)
6972
end
7073

@@ -85,6 +88,13 @@ function disable!(dc::DirectCall{K,T,P}, clock::K, when::T) where {K,T,P}
8588
delete!(dc.prefix_tree, clock)
8689
end
8790

91+
function fire!(dc::DirectCall{K,T,P}, clock::K, when::T) where {K,T,P}
92+
if dc.trajectory
93+
dc.log_likelihood += steploglikelihood(dc, dc.now, when, clock)
94+
end
95+
disable!(dc, clock, when)
96+
dc.now = when
97+
end
8898

8999
"""
90100
next(dc::DirectCall, when::TimeType, rng::AbstractRNG)

0 commit comments

Comments
 (0)