Commit 9be931d

Added to the docs how to integrate with Gen.jl, Turing.jl, and Survival.jl.
1 parent 5f522ac commit 9be931d

9 files changed: 2135 additions, 0 deletions

docs/make.jl

Lines changed: 8 additions & 0 deletions
@@ -111,6 +111,14 @@ makedocs(;
         "commonrandom.md",
         "importance_skills.md",
         "hamiltonianmontecarlo.md",
+        "gen/overview.md",
+        "gen/distribution.md",
+        "gen/generative_function.md",
+        "gen/observation_likelihood.md",
+        "gen/importance_mixture.md",
+        "gen/hmc_paths.md",
+        "gen/turing_dist.md",
+        "gen/survival_snippet.md",
     ],
     "API Reference" => [
         "contextinterface.md",

docs/src/gen/distribution.md

Lines changed: 355 additions & 0 deletions
@@ -0,0 +1,355 @@
# CompetingClocks as a Gen Distribution

This document shows how to treat an entire CompetingClocks trajectory as a single random choice in [Gen.jl](https://www.gen.dev/), enabling Bayesian inference over continuous-time discrete-event systems.

## Statistical Framework

### Path Likelihood for Generalized Semi-Markov Processes

Consider a continuous-time system over horizon ``[0, T]`` where:

- At time ``t``, a set of clocks ``\mathcal{K}_t`` is enabled
- Each clock ``k`` has a survival function ``S_k(\tau) = 1 - F_k(\tau)`` and hazard ``h_k(\tau) = f_k(\tau)/S_k(\tau)``
- The system produces a path ``x = \{(t_1, e_1), \ldots, (t_n, e_n)\}`` with ``0 < t_1 < \cdots < t_n \le T``

The path density with respect to Lebesgue measure on event times follows the standard GSMP form:

```math
\log p(x \mid \theta) = \sum_{i=1}^n \log h_{e_i}(t_i \mid \mathcal{H}_{t_i^-}, \theta) - \int_0^T \sum_{k \in \mathcal{K}_s} h_k(s \mid \mathcal{H}_{s^-}, \theta) \, ds
```

The first term sums the log-hazards at each firing time. The second term, the integrated hazard over all enabled clocks, accounts for the probability that no clock fires between consecutive events or between the last event and ``T``.
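
To see how the two terms combine, here is a minimal sketch for a special case: every clock is enabled at time 0, fires at most once, and is never reset. The helper `log_path_density` and its inputs are hypothetical and not part of CompetingClocks; only `logpdf` and `logccdf` from Distributions.jl are assumed. In this case the integrated hazard of a clock over ``[0, t]`` equals ``-\log S(t)``.

```julia
using Distributions

# Hypothetical helper (not part of CompetingClocks): log path density when
# every clock is enabled at time 0, fires at most once, and is never reset.
# `firing_times[k]` is the firing time of clock k, or `nothing` if the clock
# did not fire before the horizon T.
function log_path_density(dists, firing_times, T)
    total = 0.0
    for (d, t) in zip(dists, firing_times)
        if t !== nothing
            # First term: log-hazard at the firing time, log f(t) - log S(t)
            total += logpdf(d, t) - logccdf(d, t)
            # Second term: subtract the integrated hazard over [0, t],
            # which equals -log S(t)
            total += logccdf(d, t)
        else
            # Clock never fired: subtract its integrated hazard over [0, T]
            total += logccdf(d, T)
        end
    end
    return total
end

dists = [Gamma(2.0, 1.5), Gamma(2.0, 1.5), Exponential(3.0)]
firing_times = [0.7, nothing, 2.1]
T = 5.0

lp = log_path_density(dists, firing_times, T)

# Cross-check: fired clocks contribute log f(t), survivors contribute log S(T)
direct = logpdf(dists[1], 0.7) + logccdf(dists[2], T) + logpdf(dists[3], 2.1)
@assert isapprox(lp, direct; atol=1e-10)
```

The cross-check shows the familiar survival-analysis reading of the formula: the hazard and integrated-hazard terms of a fired clock collapse to its log-density, while a clock that never fires contributes only its log-survival to ``T``.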

### CompetingClocks Handles the Integrals

When you construct a `SamplingContext` with `path_likelihood=true`, CompetingClocks tracks both terms automatically. After simulation:

```julia
log_prob = pathloglikelihood(sampler, T)
```

returns the exact log path likelihood, including the probability of no further events until time ``T``.

### Integration with Gen

The strategy is:

1. **Simulate** paths using CompetingClocks' `next`/`fire!`/`enable!` loop
2. **Evaluate** ``\log p(x \mid \theta)`` via `pathloglikelihood`
3. **Wrap** these as `Gen.random` and `Gen.logpdf` for a custom distribution

From Gen's perspective, the entire trajectory becomes a single continuous random choice whose density is delegated to CompetingClocks.

---

## Example: Density-Dependent Birth-Death Process

A simple linear birth-death process either explodes or dies out. For stable dynamics suitable for inference, we use a logistic birth rate:

```math
\lambda(N) = \lambda_0 \cdot N \cdot \max\left(0, 1 - \frac{N}{K}\right)
```

where ``K`` is the carrying capacity. The birth rate falls to zero as ``N`` approaches ``K``, so the population stays bounded rather than diverging.
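
As a quick illustration of the shape of this rate, the following standalone sketch evaluates ``\lambda(N)`` at a few population sizes. The function name `logistic_birth_rate` and the parameter values are chosen here for illustration only.

```julia
# Hypothetical standalone version of the logistic birth rate λ(N).
logistic_birth_rate(λ0, N, K) = N <= 0 ? 0.0 : λ0 * N * max(0.0, 1.0 - N / K)

λ0, K = 2.0, 50.0
for N in (0, 10, 25, 50, 60)
    println("N = ", N, "  λ(N) = ", logistic_birth_rate(λ0, N, K))
end
```

The rate is zero for an empty population, peaks at ``N = K/2``, and is clamped to zero for ``N \ge K``.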

The complete working example is in `examples/gen_distribution.jl`.

### Basic Types

```julia
using Random
using Distributions
using CompetingClocks

# Clock keys: (:birth, 0) for births, (:death, i) for individual i's death
const ClockKey = Tuple{Symbol,Int}

# A single event in a trajectory
struct BDEvent
    time::Float64
    key::ClockKey
end

const EventPath = Vector{BDEvent}

# Population state
mutable struct BDState
    population::Set{Int}
    next_id::Int
end

BDState(N0::Int) = BDState(Set(1:N0), N0 + 1)
```

### Simulation with Path Likelihood

```julia
function simulate_bd(
    rng::AbstractRNG,
    t_max::Float64,
    λ_birth::Float64,
    K::Float64,
    death_shape::Float64,
    death_scale::Float64,
    N0::Int,
)
    # Enable path likelihood tracking
    sampler = SamplingContext(ClockKey, Float64, rng; path_likelihood=true)
    state = BDState(N0)
    path = EventPath()

    # Density-dependent birth rate
    function birth_rate(N::Int)
        N <= 0 && return 0.0
        return λ_birth * N * max(0.0, 1.0 - N / K)
    end

    # Initialize clocks
    rate = birth_rate(length(state.population))
    if rate > 0
        enable!(sampler, (:birth, 0), Exponential(inv(rate)))
    end
    for i in state.population
        enable!(sampler, (:death, i), Gamma(death_shape, death_scale))
    end

    # Simulation loop
    when, which = next(sampler)
    while !isnothing(which) && when <= t_max
        fire!(sampler, which, when)
        push!(path, BDEvent(when, which))

        # Update state
        if which[1] == :birth
            new_id = state.next_id
            state.next_id += 1
            push!(state.population, new_id)
            enable!(sampler, (:death, new_id), Gamma(death_shape, death_scale))
        elseif which[1] == :death
            delete!(state.population, which[2])
        end

        # Refresh birth clock with updated rate.
        # Note: if a death drives the rate to zero (extinction), the birth
        # clock enabled earlier is left as it was; it is not disabled here.
        rate = birth_rate(length(state.population))
        if rate > 0
            enable!(sampler, (:birth, 0), Exponential(inv(rate)))
        end

        when, which = next(sampler)
    end

    return path, sampler
end
```

### Replay for Log-Likelihood Evaluation

To compute ``\log p(x \mid \theta)`` for an arbitrary path ``x``, we "replay" it through a fresh sampler:

```julia
function bd_path_logpdf(
    path::EventPath,
    t_max::Float64,
    λ_birth::Float64,
    K::Float64,
    death_shape::Float64,
    death_scale::Float64,
    N0::Int,
)::Float64
    # Validate path structure: times in [0, t_max] and strictly increasing
    if any(ev -> ev.time < 0 || ev.time > t_max, path)
        return -Inf
    end
    for i in 2:length(path)
        if path[i].time <= path[i-1].time
            return -Inf
        end
    end

    # Fresh sampler for replay (RNG unused)
    rng = Xoshiro(0)
    sampler = SamplingContext(ClockKey, Float64, rng; path_likelihood=true)
    state = BDState(N0)

    # Same birth rate function
    function birth_rate(N::Int)
        N <= 0 && return 0.0
        return λ_birth * N * max(0.0, 1.0 - N / K)
    end

    # Initialize
    rate = birth_rate(length(state.population))
    if rate > 0
        enable!(sampler, (:birth, 0), Exponential(inv(rate)))
    end
    for i in state.population
        enable!(sampler, (:death, i), Gamma(death_shape, death_scale))
    end

    # Replay events
    for ev in path
        if !isenabled(sampler, ev.key)
            return -Inf  # Invalid: firing a disabled clock
        end

        fire!(sampler, ev.key, ev.time)

        if ev.key[1] == :birth
            new_id = state.next_id
            state.next_id += 1
            push!(state.population, new_id)
            enable!(sampler, (:death, new_id), Gamma(death_shape, death_scale))
        elseif ev.key[1] == :death
            delete!(state.population, ev.key[2])
        end

        rate = birth_rate(length(state.population))
        if rate > 0
            enable!(sampler, (:birth, 0), Exponential(inv(rate)))
        end
    end

    return pathloglikelihood(sampler, t_max)
end
```

---

## Gen Distribution Wrapper

### Distribution Type

```julia
using Gen

struct BDPathDist <: Gen.Distribution{EventPath} end
const bd_path_dist = BDPathDist()
```

### Required Methods

```julia
function Gen.random(
    ::BDPathDist,
    t_max::Float64,
    λ_birth::Float64,
    K::Float64,
    death_shape::Float64,
    death_scale::Float64,
    N0::Int,
)::EventPath
    rng = Random.default_rng()
    path, _ = simulate_bd(rng, t_max, λ_birth, K, death_shape, death_scale, N0)
    return path
end

function Gen.logpdf(
    ::BDPathDist,
    path::EventPath,
    t_max::Float64,
    λ_birth::Float64,
    K::Float64,
    death_shape::Float64,
    death_scale::Float64,
    N0::Int,
)::Float64
    return bd_path_logpdf(path, t_max, λ_birth, K, death_shape, death_scale, N0)
end

Gen.is_discrete(::BDPathDist) = false
Gen.has_output_grad(::BDPathDist) = false
# One entry per distribution argument (t_max, λ_birth, K, death_shape, death_scale, N0)
Gen.has_argument_grads(::BDPathDist) = (false, false, false, false, false, false)
```

This is sufficient for importance sampling, SMC, and Metropolis-Hastings. For gradient-based inference (HMC/NUTS), also implement `Gen.logpdf_grad` and set the corresponding entries of `has_argument_grads` to `true`.
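
One possible way to obtain such a gradient without deriving it by hand is a central finite difference of the log-density in a single argument. The sketch below is self-contained and uses a stand-in log-density (a constant-rate Poisson process on ``[0, T]``, not the birth-death model); `central_diff` and `poisson_path_logpdf` are hypothetical names introduced here. It checks the numerical derivative against the analytic one.

```julia
# Hypothetical helper: central finite difference of f at x with step h.
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# Stand-in log-density (up to terms that do not depend on λ): n events of a
# constant-rate Poisson process observed on [0, T].
poisson_path_logpdf(n, T, λ) = n * log(λ) - λ * T

n, T, λ = 12, 10.0, 1.5
numeric = central_diff(l -> poisson_path_logpdf(n, T, l), λ)
analytic = n / λ - T   # derivative of n*log(λ) - λ*T with respect to λ
@assert isapprox(numeric, analytic; atol=1e-5)
```

Applied to the birth-death example, the same idea would differentiate `bd_path_logpdf` with respect to `λ_birth` for a fixed path. A `Gen.logpdf_grad` method returns a tuple with one entry for the output followed by one per argument, using `nothing` where no gradient is provided.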

---

## Using the Distribution in a Gen Model

### Generative Model

```julia
@gen function bd_model(t_max::Float64, K::Float64, N0::Int)
    # Priors on rate parameters (Gen's gamma takes shape and scale)
    λ_birth = @trace(gamma(2.0, 1.0), :λ_birth)        # mean = 2
    mean_death = @trace(gamma(2.0, 2.0), :mean_death)  # mean = 4

    death_shape = 2.0
    death_scale = mean_death / death_shape

    # Entire trajectory as one random choice
    path = @trace(
        bd_path_dist(t_max, λ_birth, K, death_shape, death_scale, N0),
        :path,
    )

    return (λ_birth=λ_birth, mean_death=mean_death, path=path)
end
```

### Forward Simulation

```julia
t_max = 10.0
K = 50.0
N0 = 10

trace = Gen.simulate(bd_model, (t_max, K, N0))
retval = Gen.get_retval(trace)

println("λ_birth = ", retval.λ_birth)
println("mean_death = ", retval.mean_death)
println("Events = ", length(retval.path))
println("Log-probability = ", Gen.get_score(trace))
```

### Conditioning on Observed Data

Given an observed trajectory `obs_path`:

```julia
obs = Gen.choicemap((:path, obs_path))
trace, logw = Gen.generate(bd_model, (t_max, K, N0), obs)

# trace[:λ_birth] and trace[:mean_death] are sampled from the prior
# logw is the log importance weight
```

A single call to `generate` yields only one weighted prior draw. For proper posterior inference, use Gen's inference library:

```julia
n_samples = 1000  # illustrative choice

# Importance sampling: returns traces and their normalized log weights
traces, log_weights, _ = Gen.importance_sampling(bd_model, (t_max, K, N0), obs, n_samples)

# MCMC: initialize by importance resampling, then run Metropolis-Hastings
trace, = Gen.importance_resampling(bd_model, (t_max, K, N0), obs, 100)
for i in 1:1000
    trace, = Gen.mh(trace, Gen.select(:λ_birth, :mean_death))
end
```
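
To turn weighted draws into a posterior summary, the log weights need to be exponentiated in a numerically stable way, since path log-likelihoods are often large negative numbers. The following is a minimal sketch of a self-normalized posterior-mean estimate; `weighted_mean` is a hypothetical helper and the numbers are made up for illustration, not output from the model above.

```julia
# Hypothetical helper: self-normalized importance-sampling estimate of a
# posterior mean from parameter draws and (possibly unnormalized) log weights.
function weighted_mean(values, log_weights)
    m = maximum(log_weights)        # subtract the max before exponentiating
    w = exp.(log_weights .- m)
    return sum(w .* values) / sum(w)
end

# Made-up draws and log weights, for illustration only.
λ_draws = [0.8, 1.1, 1.9, 2.4]
log_w = [-310.2, -305.7, -306.1, -312.9]
println("posterior mean estimate: ", weighted_mean(λ_draws, log_w))
```

With Gen, the draws could be collected from the returned traces, for example `[tr[:λ_birth] for tr in traces]`, and passed in together with the log weights returned by `Gen.importance_sampling`.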

---

## Generalizing to Other Models

The pattern applies to any CompetingClocks simulation:

1. **Define your simulation** using `SamplingContext(...; path_likelihood=true)`
2. **Record the path** as a sequence of (time, event) pairs
3. **Write a replay function** that:
   - Creates a fresh sampler with `path_likelihood=true`
   - Replays each event with `fire!(sampler, key, time)`
   - Returns `pathloglikelihood(sampler, T)`
4. **Wrap as `Gen.Distribution`** with `random` and `logpdf`

The simulation logic (state transitions, clock distributions) is model-specific. The Gen integration remains identical.
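
The path record from step 2 and the validity checks that a replay function typically begins with can be written independently of any particular model. A minimal sketch, where `PathEvent` and `is_valid_path` are hypothetical names and not part of CompetingClocks or Gen:

```julia
# Hypothetical generic event record: a firing time plus a clock key of type K.
struct PathEvent{K}
    time::Float64
    key::K
end

# A path is valid when all times lie in [0, T] and are strictly increasing.
function is_valid_path(path::Vector{PathEvent{K}}, T::Float64) where {K}
    prev = -Inf
    for ev in path
        (0.0 <= ev.time <= T) || return false
        ev.time > prev || return false
        prev = ev.time
    end
    return true
end

path = [PathEvent(0.4, (:infect, 3)), PathEvent(1.2, (:recover, 1))]
@assert is_valid_path(path, 5.0)
@assert !is_valid_path(reverse(path), 5.0)   # times out of order
@assert !is_valid_path(path, 1.0)            # second event beyond the horizon
```

A replay function can return `-Inf` as soon as this check fails, before constructing a sampler.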

### Examples to Adapt

- **SIR epidemics**: Infection and recovery clocks with population-dependent rates
- **Queueing systems**: Arrival and service clocks with queue-length feedback
- **Reliability models**: Component failure and repair with dependent hazards
- **Chemical kinetics**: Reaction clocks with mass-action propensities

In each case, the path likelihood integral is handled automatically by CompetingClocks.
