Skip to content

Commit bb9e333

Browse files
committed
adds Hawkes process benchmark.
1 parent 04dca4c commit bb9e333

File tree

1 file changed

+187
-0
lines changed

1 file changed

+187
-0
lines changed
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
---
2+
title: Mutually exciting Hawkes process
3+
author: Guilherme Zagatti
4+
weave_options:
5+
fig_ext: ".png"
6+
---
7+
8+
```julia
9+
using JumpProcesses, Graphs, Plots, Statistics, BenchmarkTools
10+
```
11+
12+
# Model and example solutions
13+
14+
```julia
15+
function reset_history!(h; start_time = nothing)
16+
if start_time === nothing
17+
start_time = -Inf
18+
end
19+
@inbounds for i = 1:length(h)
20+
hi = h[i]
21+
ix = 0
22+
while ((ix + 1) <= length(hi)) && hi[ix+1] <= start_time
23+
ix += 1
24+
end
25+
end
26+
h[i] = ix == 0 ? eltype(h)[] : hi[1:ix]
27+
end
28+
nothing
29+
end
30+
31+
function hawkes_rate(i::Int, g, h)
32+
function rate(u, p, t)
33+
λ, α, β = p
34+
x = zero(typeof(t))
35+
for j in g[i]
36+
for _t in reverse(h[j])
37+
λij = α*exp(-β*(t - _t))
38+
if λij ≈ 0 break end
39+
x += λij
40+
end
41+
end
42+
return λ + x
43+
end
44+
return rate
45+
end
46+
47+
function hawkes_jump(i::Int, g, h)
48+
rate = hawkes_rate(i, g, h)
49+
lrate(u, p, t) = p[1]
50+
urate = rate
51+
function L(u, p, t)
52+
_lrate = lrate(u, p, t)
53+
_urate = urate(u, p, t)
54+
return _urate == _lrate ? typemax(t) : 1/(2*_urate)
55+
end
56+
function affect!(integrator)
57+
push!(h[i], integrator.t)
58+
integrator.u[i] += 1
59+
end
60+
return VariableRateJump(rate, affect!; lrate=lrate, urate=urate, L=L)
61+
end
62+
63+
function hawkes_jump(u, g, h)
64+
return [hawkes_jump(i, g, h) for i in 1:length(u)]
65+
end
66+
67+
function hawkes_problem(p, agg::QueueMethod; u=[0.], tspan=(0., 50.), save_positions=(false, true),
68+
g = [[1]], h = [[]])
69+
dprob = DiscreteProblem(u, tspan, p)
70+
jumps = hawkes_jump(u, g, h)
71+
jprob = JumpProblem(dprob, agg, jumps...;
72+
dep_graph=g, save_positions=save_positions)
73+
return jprob
74+
end
75+
76+
function f!(du, u, p, t)
77+
du .= 0
78+
nothing
79+
end
80+
81+
function hawkes_problem(p, agg; u=[0.], tspan=(0., 50.), save_positions=(false, true),
82+
g = [[1]], h = [[]])
83+
oprob = ODEProblem(f!, u, tspan, p)
84+
jumps = hawkes_jump(u, g, h)
85+
jprob = JumpProblem(oprob, agg, jumps...; save_positions=save_positions)
86+
return jprob
87+
end
88+
```
89+
90+
```julia
91+
methods = (Direct(), QueueMethod())
92+
shortlabels = [string(leg)[15:end-2] for leg in methods]
93+
94+
V = 10
95+
G = erdos_renyi(V, 0.2)
96+
g = [neighbors(G, i) for i in 1:nv(G)]
97+
98+
u = [0. for i in 1:nv(G)]
99+
p = (0.5, 0.1, 2.0)
100+
tspan = (0., 50.)
101+
h = [eltype(tspan)[] for _ in 1:nv(G)]
102+
103+
plots = []
104+
for (i, method) in enumerate(methods)
105+
jump_prob = hawkes_problem(p, method; u=u, tspan=tspan, g=g, h=h)
106+
reset_history!(h)
107+
if typeof(method) <: QueueMethod
108+
sol = solve(jump_prob, SSAStepper())
109+
else
110+
sol = solve(jump_prob, Tsit5())
111+
end
112+
push!(plots, plot(sol.t, sol[1:V, :]', title=shortlabels[i], legend=false, format=fmt))
113+
end
114+
plot(plots..., layout=(1, 2), format=fmt)
115+
```
116+
117+
# Benchmarking performance of the methods
118+
119+
```julia
120+
Vs = append!([1], 5:5:100)
121+
Gs = [erdos_renyi(V, 0.2) for V in Vs]
122+
nsims = 50
123+
benchmarks = Vector{Vector{BenchmarkTools.Trial}}()
124+
# initialize variables for BenchmarkTools
125+
_G = nothing
126+
_g = nothing
127+
_u = nothing
128+
_h = nothing
129+
_jump_prob = nothing
130+
_stepper = nothing
131+
_trial = nothing
132+
for method in methods
133+
println("Method: $(method)")
134+
push!(benchmarks, Vector{BenchmarkTools.Trial}())
135+
_bs = benchmarks[end]
136+
for (i, V) in enumerate(Vs)
137+
trial = try
138+
_G = Gs[i]
139+
_g = [neighbors(_G, i) for i in 1:nv(_G)]
140+
_u = [0. for i in 1:nv(_G)]
141+
_h = [eltype(tspan)[] for _ in 1:nv(_G)]
142+
_jump_prob = hawkes_problem(p, method; u=u, tspan=tspan, g=g, h=h)
143+
if typeof(method) <: QueueMethod
144+
_stepper = SSAStepper()
145+
else
146+
_stepper = Tsit5()
147+
end
148+
@benchmark(
149+
solve(_jump_prob, _stepper),
150+
setup=(reset_history!(_h)),
151+
samples=nsims,
152+
evals=1,
153+
seconds=60,
154+
)
155+
catch
156+
BenchmarkTools.Trial(BenchmarkTools.Parameters(samples=nsims, evals=1, seconds=60))
157+
end
158+
push!(_bs, _trial)
159+
if (V == 1 || V % 20 == 0) && (length(_bs[end]) > 0)
160+
println("\tV = $V, b = $(_bs[end])")
161+
end
162+
end
163+
time = Vector{Vector{Float64}}(undef, nsims)
164+
run_benchmark!(time, jump_prob, stepper)
165+
push!(benchmarks, time)
166+
end
167+
```
168+
169+
```julia
170+
plot(yscale=:log10, xlabel="V", ylabel="Time (ns)", legend_position=:outertopright)
171+
for (i, method) in enumerate(methods)
172+
_bs, _Vs = [], []
173+
for (i, b) in enumerate(benchmarks)
174+
if length(b) > 0
175+
push!(_bs, median(b.times))
176+
push!(_Vs, Vs[i])
177+
end
178+
end
179+
plot!(_Vs, _bs, label="$method")
180+
end
181+
title!("Simulation, $nsims samples: processes × median time")
182+
```
183+
184+
```julia, echo = false
185+
using SciMLBenchmarks
186+
SciMLBenchmarks.bench_footer(WEAVE_ARGS[:folder],WEAVE_ARGS[:file])
187+
```

0 commit comments

Comments
 (0)