| 1 | +# CompetingClocks as a Gen Distribution |
| 2 | + |
| 3 | +This document shows how to treat an entire CompetingClocks trajectory as a single random choice in [Gen.jl](https://www.gen.dev/), enabling Bayesian inference over continuous-time discrete-event systems. |
| 4 | + |
| 5 | +## Statistical Framework |
| 6 | + |
| 7 | +### Path Likelihood for Generalized Semi-Markov Processes |
| 8 | + |
| 9 | +Consider a continuous-time system over horizon ``[0, T]`` where: |
| 10 | + |
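| 11 | +- At time ``t``, a set of clocks ``\mathcal{K}_t`` is enabled
| 12 | +- Each clock ``k`` has a firing-time distribution with CDF ``F_k`` and density ``f_k``, measured in time ``\tau`` since the clock was enabled, giving survival function ``S_k(\tau) = 1 - F_k(\tau)`` and hazard ``h_k(\tau) = f_k(\tau)/S_k(\tau)``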
| 13 | +- The system produces a path ``x = \{(t_1, e_1), \ldots, (t_n, e_n)\}`` with ``0 < t_1 < \cdots < t_n \le T`` |
| 14 | + |
| 15 | +The path density with respect to Lebesgue measure on event times follows the standard GSMP form: |
| 16 | + |
| 17 | +```math |
| 18 | +\log p(x \mid \theta) = \sum_{i=1}^n \log h_{e_i}(t_i \mid \mathcal{H}_{t_i^-}, \theta) - \int_0^T \sum_{k \in \mathcal{K}_s} h_k(s \mid \mathcal{H}_{s^-}, \theta) \, ds |
| 19 | +``` |
| 20 | + |
| 21 | +The first term sums the log-hazard of the clock that fires at each event time. The second term subtracts the hazard integrated over all enabled clocks, which is the log-probability that nothing else fires between events or between the last event and ``T``.
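
As a sanity check on the two terms, the sketch below evaluates the formula by hand for a toy case that is not part of the package: two clocks enabled at time 0, one exponential and one Weibull (both chosen only because their hazards have simple closed forms). Clock `:a` fires once at `t1` and is not re-enabled; clock `:b` never fires before the horizon. The hazard form should agree with the familiar "density of what fired times survival of what did not" form.

```julia
# Hand-computed path log-likelihood for two competing clocks (base Julia only).
# Illustrative assumptions: clock :a is exponential with rate 0.8,
# clock :b is Weibull with shape 2 and scale 3. Both are enabled at time 0.
hazard_a(t) = 0.8                    # constant exponential hazard
hazard_b(t) = (2 / 3) * (t / 3)      # Weibull hazard (k/σ)(t/σ)^(k-1), k = 2, σ = 3

# Midpoint-rule approximation of the integral of h over [lo, hi]
function integrate(h, lo, hi; n=10_000)
    step = (hi - lo) / n
    return step * sum(h(lo + (i - 0.5) * step) for i in 1:n)
end

t1, T = 0.7, 2.0   # :a fires at t1; :b stays enabled until the horizon T

# Hazard form: log h_a(t1) - ∫_0^{t1} h_a(s) ds - ∫_0^T h_b(s) ds
loglik_hazard = log(hazard_a(t1)) -
                integrate(hazard_a, 0.0, t1) -
                integrate(hazard_b, 0.0, T)

# Density/survival form: log f_a(t1) + log S_b(T)
log_f_a = log(0.8) - 0.8 * t1        # log of 0.8 * exp(-0.8 * t1)
log_S_b = -(T / 3)^2                 # log of exp(-(T/σ)^k)
loglik_closed = log_f_a + log_S_b

println(loglik_hazard, " vs ", loglik_closed)
```

The integral for `:a` stops at `t1` because the clock leaves the enabled set when it fires, while the integral for `:b` runs to ``T`` and acts as a censoring term.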
| 22 | + |
| 23 | +### CompetingClocks Handles the Integrals |
| 24 | + |
| 25 | +When you construct a `SamplingContext` with `path_likelihood=true`, CompetingClocks tracks both terms automatically. After simulation: |
| 26 | + |
| 27 | +```julia |
| 28 | +log_prob = pathloglikelihood(sampler, T) |
| 29 | +``` |
| 30 | + |
| 31 | +returns the exact log path likelihood, including the probability of no further events until time ``T``. |
| 32 | + |
| 33 | +### Integration with Gen |
| 34 | + |
| 35 | +The strategy is: |
| 36 | + |
| 37 | +1. **Simulate** paths using CompetingClocks' `next`/`fire!`/`enable!` loop |
| 38 | +2. **Evaluate** ``\log p(x \mid \theta)`` via `pathloglikelihood` |
| 39 | +3. **Wrap** these as `Gen.random` and `Gen.logpdf` for a custom distribution |
| 40 | + |
| 41 | +From Gen's perspective, the entire trajectory becomes a single continuous random choice whose density is delegated to CompetingClocks. |
| 42 | + |
| 43 | +--- |
| 44 | + |
| 45 | +## Example: Density-Dependent Birth-Death Process |
| 46 | + |
| 47 | +A simple linear birth-death process either explodes or dies out. For stable dynamics suitable for inference, we use a logistic birth rate: |
| 48 | + |
| 49 | +```math |
| 50 | +\lambda(N) = \lambda_0 \cdot N \cdot \max\left(0, 1 - \frac{N}{K}\right) |
| 51 | +``` |
| 52 | + |
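| 53 | +where ``K`` is the carrying capacity. Births stop once ``N`` reaches ``K``, so a population that starts at or below ``K`` stays bounded by ``K`` and fluctuates below it, near the level where births balance deaths, rather than diverging.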
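
To make the shape of this rate concrete, here is a small standalone sketch that tabulates it. The values `λ0 = 0.5` and `K = 50` are illustrative choices, not parameters taken from the example file.

```julia
# Logistic birth rate from the formula above (illustrative parameter values).
birth_rate(N, λ0, K) = λ0 * N * max(0.0, 1.0 - N / K)

λ0, K = 0.5, 50.0
for N in (0, 10, 25, 50, 60)
    println("N = ", N, "  birth rate = ", birth_rate(N, λ0, K))
end

# The rate vanishes at N = 0 and for N >= K, and is largest at N = K/2.
peak_N = argmax(N -> birth_rate(N, λ0, K), 0:60)
```

Because the rate is exactly zero for ``N \ge K``, the birth clock has to be absent from the enabled set in those states, which is why the simulation code only enables it when the rate is positive.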
| 54 | + |
| 55 | +The complete working example is in `examples/gen_distribution.jl`. |
| 56 | + |
| 57 | +### Basic Types |
| 58 | + |
| 59 | +```julia |
| 60 | +using Random |
| 61 | +using Distributions |
| 62 | +using CompetingClocks |
| 63 | + |
| 64 | +# Clock keys: (:birth, 0) for births, (:death, i) for individual i's death |
| 65 | +const ClockKey = Tuple{Symbol,Int} |
| 66 | + |
| 67 | +# A single event in a trajectory |
| 68 | +struct BDEvent |
| 69 | + time::Float64 |
| 70 | + key::ClockKey |
| 71 | +end |
| 72 | + |
| 73 | +const EventPath = Vector{BDEvent} |
| 74 | + |
| 75 | +# Population state |
| 76 | +mutable struct BDState |
| 77 | + population::Set{Int} |
| 78 | + next_id::Int |
| 79 | +end |
| 80 | + |
| 81 | +BDState(N0::Int) = BDState(Set(1:N0), N0 + 1) |
| 82 | +``` |
| 83 | + |
| 84 | +### Simulation with Path Likelihood |
| 85 | + |
| 86 | +```julia |
| 87 | +function simulate_bd( |
| 88 | + rng::AbstractRNG, |
| 89 | + t_max::Float64, |
| 90 | + λ_birth::Float64, |
| 91 | + K::Float64, |
| 92 | + death_shape::Float64, |
| 93 | + death_scale::Float64, |
| 94 | + N0::Int, |
| 95 | +) |
| 96 | + # Enable path likelihood tracking |
| 97 | + sampler = SamplingContext(ClockKey, Float64, rng; path_likelihood=true) |
| 98 | + state = BDState(N0) |
| 99 | + path = EventPath() |
| 100 | + |
| 101 | + # Density-dependent birth rate |
| 102 | + function birth_rate(N::Int) |
| 103 | + N <= 0 && return 0.0 |
| 104 | + return λ_birth * N * max(0.0, 1.0 - N / K) |
| 105 | + end |
| 106 | + |
| 107 | + # Initialize clocks |
| 108 | + rate = birth_rate(length(state.population)) |
| 109 | + if rate > 0 |
| 110 | + enable!(sampler, (:birth, 0), Exponential(inv(rate))) |
| 111 | + end |
| 112 | + for i in state.population |
| 113 | + enable!(sampler, (:death, i), Gamma(death_shape, death_scale)) |
| 114 | + end |
| 115 | + |
| 116 | + # Simulation loop |
| 117 | + when, which = next(sampler) |
| 118 | + while !isnothing(which) && when <= t_max |
| 119 | + fire!(sampler, which, when) |
| 120 | + push!(path, BDEvent(when, which)) |
| 121 | + |
| 122 | + # Update state |
| 123 | + if which[1] == :birth |
| 124 | + new_id = state.next_id |
| 125 | + state.next_id += 1 |
| 126 | + push!(state.population, new_id) |
| 127 | + enable!(sampler, (:death, new_id), Gamma(death_shape, death_scale)) |
| 128 | + elseif which[1] == :death |
| 129 | + delete!(state.population, which[2]) |
| 130 | + end |
| 131 | + |
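| 132 | +        # Refresh the birth clock with the updated rate. A clock left over from a
| 133 | +        # positive rate must be disabled (assumes `disable!(sampler, key)` exists).
| 134 | +        rate = birth_rate(length(state.population))
| 135 | +        rate <= 0 && isenabled(sampler, (:birth, 0)) && disable!(sampler, (:birth, 0))
| 136 | +        rate > 0 && enable!(sampler, (:birth, 0), Exponential(inv(rate)))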
| 137 | + |
| 138 | + when, which = next(sampler) |
| 139 | + end |
| 140 | + |
| 141 | + return path, sampler |
| 142 | +end |
| 143 | +``` |
| 144 | + |
| 145 | +### Replay for Log-Likelihood Evaluation |
| 146 | + |
| 147 | +To compute ``\log p(x \mid \theta)`` for an arbitrary path ``x``, we "replay" it through a fresh sampler: |
| 148 | + |
| 149 | +```julia |
| 150 | +function bd_path_logpdf( |
| 151 | + path::EventPath, |
| 152 | + t_max::Float64, |
| 153 | + λ_birth::Float64, |
| 154 | + K::Float64, |
| 155 | + death_shape::Float64, |
| 156 | + death_scale::Float64, |
| 157 | + N0::Int, |
| 158 | +)::Float64 |
| 159 | + # Validate path structure |
| 160 | + if any(ev -> ev.time < 0 || ev.time > t_max, path) |
| 161 | + return -Inf |
| 162 | + end |
| 163 | + for i in 2:length(path) |
| 164 | + if path[i].time <= path[i-1].time |
| 165 | + return -Inf |
| 166 | + end |
| 167 | + end |
| 168 | + |
| 169 | + # Fresh sampler for replay (RNG unused) |
| 170 | + rng = Xoshiro(0) |
| 171 | + sampler = SamplingContext(ClockKey, Float64, rng; path_likelihood=true) |
| 172 | + state = BDState(N0) |
| 173 | + |
| 174 | + # Same birth rate function |
| 175 | + function birth_rate(N::Int) |
| 176 | + N <= 0 && return 0.0 |
| 177 | + return λ_birth * N * max(0.0, 1.0 - N / K) |
| 178 | + end |
| 179 | + |
| 180 | + # Initialize |
| 181 | + rate = birth_rate(length(state.population)) |
| 182 | + if rate > 0 |
| 183 | + enable!(sampler, (:birth, 0), Exponential(inv(rate))) |
| 184 | + end |
| 185 | + for i in state.population |
| 186 | + enable!(sampler, (:death, i), Gamma(death_shape, death_scale)) |
| 187 | + end |
| 188 | + |
| 189 | + # Replay events |
| 190 | + for ev in path |
| 191 | + if !isenabled(sampler, ev.key) |
| 192 | + return -Inf # Invalid: firing a disabled clock |
| 193 | + end |
| 194 | + |
| 195 | + fire!(sampler, ev.key, ev.time) |
| 196 | + |
| 197 | + if ev.key[1] == :birth |
| 198 | + new_id = state.next_id |
| 199 | + state.next_id += 1 |
| 200 | + push!(state.population, new_id) |
| 201 | + enable!(sampler, (:death, new_id), Gamma(death_shape, death_scale)) |
| 202 | + elseif ev.key[1] == :death |
| 203 | + delete!(state.population, ev.key[2]) |
| 204 | + end |
| 205 | + |
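| 206 | +        # Refresh the birth clock; disable a stale one at zero rate (assumes `disable!(sampler, key)`)
| 207 | +        rate = birth_rate(length(state.population))
| 208 | +        rate <= 0 && isenabled(sampler, (:birth, 0)) && disable!(sampler, (:birth, 0))
| 209 | +        rate > 0 && enable!(sampler, (:birth, 0), Exponential(inv(rate)))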
| 210 | + end |
| 211 | + |
| 212 | + return pathloglikelihood(sampler, t_max) |
| 213 | +end |
| 214 | +``` |
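
One way to gain confidence in a replay function is to compare it with a case that has a closed form. With `death_shape = 1.0` the Gamma lifetimes reduce to exponentials with rate `1 / death_scale`, the process is Markov, and the total hazard is constant between events. The sketch below is a hypothetical helper, independent of CompetingClocks, that computes the path log-likelihood for that special case. It uses a simplified event representation (`(time, kind)` tuples) and a per-capita death rate `μ_death`, both introduced here only for illustration.

```julia
# Closed-form path log-likelihood for the Markov special case (exponential lifetimes).
# Hypothetical cross-check; events are (time, kind) pairs with kind in (:birth, :death).
function markov_bd_logpdf(events, t_max, λ_birth, K, μ_death, N0)
    birth_rate(N) = N <= 0 ? 0.0 : λ_birth * N * max(0.0, 1.0 - N / K)
    N = N0
    t_prev = 0.0
    logp = 0.0
    for (t, kind) in events
        # Event times must be strictly increasing and lie in (0, t_max]
        (t <= t_prev || t > t_max) && return -Inf
        # Survival of every enabled clock over (t_prev, t)
        logp -= (birth_rate(N) + μ_death * N) * (t - t_prev)
        if kind == :birth
            logp += log(birth_rate(N))   # -Inf when no birth is possible at this N
            N += 1
        elseif kind == :death
            N > 0 || return -Inf
            logp += log(μ_death)         # hazard of the single individual that died
            N -= 1
        else
            return -Inf
        end
        t_prev = t
    end
    # No further events between the last event and the horizon
    logp -= (birth_rate(N) + μ_death * N) * (t_max - t_prev)
    return logp
end

events = [(0.4, :birth), (0.7, :death)]
lp = markov_bd_logpdf(events, 1.0, 1.0, 10.0, 0.5, 2)
```

Under these assumptions, `bd_path_logpdf` called with `death_shape = 1.0` and `death_scale = 1 / μ_death` on the corresponding path should return the same value up to floating-point error, whichever individual each death is assigned to.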
| 215 | + |
| 216 | +--- |
| 217 | + |
| 218 | +## Gen Distribution Wrapper |
| 219 | + |
| 220 | +### Distribution Type |
| 221 | + |
| 222 | +```julia |
| 223 | +using Gen |
| 224 | + |
| 225 | +struct BDPathDist <: Gen.Distribution{EventPath} end |
| 226 | +const bd_path_dist = BDPathDist() |
| 227 | +``` |
| 228 | + |
| 229 | +### Required Methods |
| 230 | + |
| 231 | +```julia |
| 232 | +function Gen.random( |
| 233 | + ::BDPathDist, |
| 234 | + t_max::Float64, |
| 235 | + λ_birth::Float64, |
| 236 | + K::Float64, |
| 237 | + death_shape::Float64, |
| 238 | + death_scale::Float64, |
| 239 | + N0::Int, |
| 240 | +)::EventPath |
| 241 | + rng = Random.default_rng() |
| 242 | + path, _ = simulate_bd(rng, t_max, λ_birth, K, death_shape, death_scale, N0) |
| 243 | + return path |
| 244 | +end |
| 245 | + |
| 246 | +function Gen.logpdf( |
| 247 | + ::BDPathDist, |
| 248 | + path::EventPath, |
| 249 | + t_max::Float64, |
| 250 | + λ_birth::Float64, |
| 251 | + K::Float64, |
| 252 | + death_shape::Float64, |
| 253 | + death_scale::Float64, |
| 254 | + N0::Int, |
| 255 | +)::Float64 |
| 256 | + return bd_path_logpdf(path, t_max, λ_birth, K, death_shape, death_scale, N0) |
| 257 | +end |
| 258 | + |
| 259 | +Gen.is_discrete(::BDPathDist) = false |
| 260 | +Gen.has_output_grad(::BDPathDist) = false |
| 261 | +Gen.has_argument_grads(::BDPathDist) = (false, false, false, false, false, false) |
| 262 | +``` |
| 263 | + |
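| 264 | +This is sufficient for importance sampling, SMC, and Metropolis-Hastings. Gradient-based inference over the parameters (for example `Gen.hmc` or `Gen.mala`) additionally requires `Gen.logpdf_grad` and `true` entries in `has_argument_grads` for the differentiable arguments.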
| 265 | + |
| 266 | +--- |
| 267 | + |
| 268 | +## Using the Distribution in a Gen Model |
| 269 | + |
| 270 | +### Generative Model |
| 271 | + |
| 272 | +```julia |
| 273 | +@gen function bd_model(t_max::Float64, K::Float64, N0::Int) |
| 274 | + # Priors on rate parameters |
| 275 | + λ_birth = @trace(gamma(2.0, 1.0), :λ_birth) # mean = 2 |
| 276 | + mean_death = @trace(gamma(2.0, 2.0), :mean_death) # mean = 4 |
| 277 | + |
| 278 | + death_shape = 2.0 |
| 279 | + death_scale = mean_death / death_shape |
| 280 | + |
| 281 | + # Entire trajectory as one random choice |
| 282 | + path = @trace( |
| 283 | + bd_path_dist(t_max, λ_birth, K, death_shape, death_scale, N0), |
| 284 | + :path, |
| 285 | + ) |
| 286 | + |
| 287 | + return (λ_birth=λ_birth, mean_death=mean_death, path=path) |
| 288 | +end |
| 289 | +``` |
| 290 | + |
| 291 | +### Forward Simulation |
| 292 | + |
| 293 | +```julia |
| 294 | +t_max = 10.0 |
| 295 | +K = 50.0 |
| 296 | +N0 = 10 |
| 297 | + |
| 298 | +trace = Gen.simulate(bd_model, (t_max, K, N0)) |
| 299 | +retval = Gen.get_retval(trace) |
| 300 | + |
| 301 | +println("λ_birth = ", retval.λ_birth) |
| 302 | +println("mean_death = ", retval.mean_death) |
| 303 | +println("Events = ", length(retval.path)) |
| 304 | +println("Log-probability = ", Gen.get_score(trace)) |
| 305 | +``` |
| 306 | + |
| 307 | +### Conditioning on Observed Data |
| 308 | + |
| 309 | +Given an observed trajectory `obs_path`: |
| 310 | + |
| 311 | +```julia |
| 312 | +obs = Gen.choicemap((:path, obs_path)) |
| 313 | +trace, logw = Gen.generate(bd_model, (t_max, K, N0), obs) |
| 314 | + |
| 315 | +# trace[:λ_birth] and trace[:mean_death] are sampled from the prior |
| 316 | +# logw is the log importance weight, here log p(obs_path | λ_birth, mean_death)
| 317 | +``` |
| 318 | + |
| 319 | +For proper posterior inference, use Gen's inference library: |
| 320 | + |
| 321 | +```julia |
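| 322 | +n_samples = 1000  # importance sampling; illustrative sample count
| 323 | +traces, log_weights, _ = Gen.importance_sampling(bd_model, (t_max, K, N0), obs, n_samples)  # normalized log weights
| 324 | +
| 325 | +# MCMC: initialize by importance resampling, then run MH on the parameters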
| 326 | +trace, = Gen.importance_resampling(bd_model, (t_max, K, N0), obs, 100) |
| 327 | +for i in 1:1000 |
| 328 | + trace, = Gen.mh(trace, Gen.select(:λ_birth, :mean_death)) |
| 329 | +end |
| 330 | +``` |
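
The weights that come back from importance sampling are on the log scale (to my understanding, `Gen.importance_sampling` returns normalized log weights, while `Gen.generate` returns a single unnormalized one). A minimal sketch of turning log weights into a posterior summary is shown below. It uses base Julia only, and the three parameter draws and their weights are made-up numbers for illustration.

```julia
# Convert log weights (normalized or not) to normalized weights without overflow.
function normalize_weights(log_w::Vector{Float64})
    shifted = exp.(log_w .- maximum(log_w))   # largest weight maps to exp(0) = 1
    return shifted ./ sum(shifted)
end

# Illustrative numbers only: three prior draws of a parameter and their log weights.
θ_samples = [1.0, 2.0, 3.0]
log_w = [-10.0, -10.0 + log(2.0), -10.0]

w = normalize_weights(log_w)             # approximately [0.25, 0.5, 0.25]
posterior_mean = sum(w .* θ_samples)     # weighted average of the draws
ess = 1 / sum(abs2, w)                   # effective sample size
```

In a real run, `θ_samples` would be something like `[tr[:λ_birth] for tr in traces]`. An effective sample size far below the number of samples suggests the prior is a poor proposal for the observed path, in which case the MCMC route is usually the better choice.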
| 331 | + |
| 332 | +--- |
| 333 | + |
| 334 | +## Generalizing to Other Models |
| 335 | + |
| 336 | +The pattern applies to any CompetingClocks simulation: |
| 337 | + |
| 338 | +1. **Define your simulation** using `SamplingContext(...; path_likelihood=true)` |
| 339 | +2. **Record the path** as a sequence of (time, event) pairs |
| 340 | +3. **Write a replay function** that: |
| 341 | + - Creates a fresh sampler with `path_likelihood=true` |
| 342 | + - Replays each event with `fire!(sampler, key, time)` |
| 343 | + - Returns `pathloglikelihood(sampler, T)` |
| 344 | +4. **Wrap as `Gen.Distribution`** with `random` and `logpdf` |
| 345 | + |
| 346 | +The simulation logic (state transitions, clock distributions) is model-specific. The Gen integration remains identical. |
| 347 | + |
| 348 | +### Examples to Adapt |
| 349 | + |
| 350 | +- **SIR epidemics**: Infection and recovery clocks with population-dependent rates |
| 351 | +- **Queueing systems**: Arrival and service clocks with queue-length feedback |
| 352 | +- **Reliability models**: Component failure and repair with dependent hazards |
| 353 | +- **Chemical kinetics**: Reaction clocks with mass-action propensities |
| 354 | + |
| 355 | +In each case, the path likelihood integral is handled automatically by CompetingClocks. |