|
| 1 | +""" |
| 2 | +Implementation of Grid Partial Rejection Sampling of [Moka, Sarat B. and Kroese, Dirk P. (2020)](https://espace.library.uq.edu.au/view/UQ:d924abb) |
| 3 | +""" |
| 4 | + |
| 5 | +abstract type AbstractCellGridPRS end |
| 6 | + |
| 7 | +window(cell::AbstractCellGridPRS) = cell.window |
| 8 | +dimension(cell::AbstractCellGridPRS) = dimension(window(cell)) |
| 9 | + |
| 10 | +mutable struct GraphCellGridPRS{T} <: AbstractCellGridPRS |
| 11 | + window::GraphNode{T} |
| 12 | + value::T |
| 13 | +end |
| 14 | + |
| 15 | +mutable struct SpatialCellGridPRS{T<:Vector{Float64}} <: AbstractCellGridPRS |
| 16 | + window::Union{RectangleWindow,SquareWindow} |
| 17 | + value::Vector{T} |
| 18 | +end |
| 19 | + |
| 20 | +Base.isempty(cell::SpatialCellGridPRS) = isempty(cell.value) |
| 21 | +Base.iterate(cell::AbstractCellGridPRS, state=1) = iterate(cell.value, state) |
| 22 | + |
| 23 | +function generate_sample_grid_prs( |
| 24 | + pp::AbstractPointProcess{T}; |
| 25 | + rng=-1 |
| 26 | +)::Vector{T} where {T} |
| 27 | + rng = getRNG(rng) |
| 28 | + g = weighted_interaction_graph(pp; rng=rng) |
| 29 | + |
| 30 | + cells = initialize_cells(pp, LG.nv(g)) |
| 31 | + resample_indices = Set(1:length(cells)) |
| 32 | + |
| 33 | + while !isempty(resample_indices) |
| 34 | + generate_sample!(cells, resample_indices, pp; rng=rng) |
| 35 | + resample_indices = find_cells_to_resample_indices!(g, cells, pp; rng=rng) |
| 36 | + end |
| 37 | + return vcat(getfield.(cells, :value)...) |
| 38 | +end |
| 39 | + |
| 40 | +@doc raw""" |
| 41 | +The dependency graph between the cells is a weighted king graph |
| 42 | +where each edge ``\{i,j\}`` is associated to an event and carries a uniform mark ``U_{ij}``. |
| 43 | +Note https://en.wikipedia.org/wiki/King%27s_graph |
| 44 | +""" |
| 45 | +function weighted_interaction_graph( |
| 46 | + pp::AbstractSpatialPointProcess; |
| 47 | + rng=-1 |
| 48 | +)::SWG.SimpleWeightedGraph |
| 49 | + rng = getRNG(rng) |
| 50 | + g = SWG.SimpleWeightedGraph(king_graph(ceil(Int, inv(pp.r)))) |
| 51 | + for e in LG.edges(g) |
| 52 | + i, j = Tuple(e) |
| 53 | + @inbounds g.weights[i, j] = g.weights[j, i] = rand(rng, weighttype(g)) |
| 54 | + end |
| 55 | + return g |
| 56 | +end |
| 57 | + |
| 58 | +function generate_sample!( |
| 59 | + cell::AbstractCellGridPRS, |
| 60 | + pp::AbstractPointProcess; |
| 61 | + rng=-1 |
| 62 | +) |
| 63 | + rng = getRNG(rng) |
| 64 | + cell.value = generate_sample(pp; win=cell.window, rng=rng) |
| 65 | +end |
| 66 | + |
| 67 | +function generate_sample!( |
| 68 | + cells::Vector{T}, |
| 69 | + indices, |
| 70 | + pp::AbstractPointProcess; |
| 71 | + rng=-1 |
| 72 | +) where {T<:AbstractCellGridPRS} |
| 73 | + rng = getRNG(rng) |
| 74 | + for i in indices |
| 75 | + generate_sample!(cells[i], pp; rng=rng) |
| 76 | + end |
| 77 | +end |
| 78 | + |
| 79 | +@doc raw""" |
| 80 | +Identify bad events and return the corresponding cells' index. |
| 81 | +An event ``\{i,j\}`` is said to be \"bad\" |
| 82 | +
|
| 83 | +```math |
| 84 | + \left\{U_{ij} > \exp \left[ -\sum_{x \in C_i} \sum_{y \in C_j} V(x,y) \right] \right\} |
| 85 | +``` |
| 86 | +
|
| 87 | +where ``U_{ij}`` is stored as the weight of edge ``\{i,j\}`` in the dependency graph ``g``. |
| 88 | +Note: when a bad event occurs, the corresponding ``U_{ij}`` is resampled hence the "!" |
| 89 | +""" |
| 90 | +function find_bad_cells_indices!( |
| 91 | + g::SWG.SimpleWeightedGraph{T,U}, |
| 92 | + cells::Vector{V}, |
| 93 | + pp::AbstractPointProcess; |
| 94 | + rng=-1 |
| 95 | +)::Set{T} where {T,U,V<:AbstractCellGridPRS} |
| 96 | + rng = getRNG(rng) |
| 97 | + bad = Set{T}() |
| 98 | + for e ∈ LG.edges(g) |
| 99 | + i, j = Tuple(e) |
| 100 | + if e.weight > gibbs_interaction(pp, cells[i], cells[j]) |
| 101 | + union!(bad, i, j) |
| 102 | + # resample U_ij associated to the bad event |
| 103 | + @inbounds g.weights[i, j] = g.weights[j, i] = rand(rng, U) |
| 104 | + end |
| 105 | + end |
| 106 | + return bad |
| 107 | +end |
| 108 | + |
| 109 | +""" |
| 110 | +Identify which events need to be resampled and return the corresponding cells' index. |
| 111 | +This is the core of the Partial Rejection Sampling algorithm. |
| 112 | +Note: when an event needs to be resampled, the corresponding mark ``U_{ij}`` is resampled hence the "!" |
| 113 | +""" |
| 114 | +function find_cells_to_resample_indices!( |
| 115 | + g::SWG.SimpleWeightedGraph{T,U}, |
| 116 | + cells::Vector{V}, |
| 117 | + pp::AbstractPointProcess; |
| 118 | + rng=-1 |
| 119 | +)::Set{T} where {T,U,V<:AbstractCellGridPRS} |
| 120 | + rng = getRNG(rng) |
| 121 | + R = find_bad_cells_indices!(g, cells, pp; rng=rng) |
| 122 | + ∂R, ∂R_tmp = copy(R), empty(R) |
| 123 | + while !isempty(∂R) |
| 124 | + for i ∈ ∂R |
| 125 | + isempty(cells[i]) && continue |
| 126 | + for j ∈ LG.neighbors(g, i) |
| 127 | + if j ∈ R |
| 128 | + if is_inner_interaction_possible(pp, cells[i], cells[j]) |
| 129 | + @inbounds g.weights[i, j] = g.weights[j, i] = rand(rng, U) |
| 130 | + end |
| 131 | + elseif is_outer_interaction_possible(pp, cells[i], cells[j]) |
| 132 | + union!(R, j) |
| 133 | + union!(∂R_tmp, j) |
| 134 | + @inbounds g.weights[i, j] = g.weights[j, i] = rand(rng, U) |
| 135 | + end |
| 136 | + end |
| 137 | + end |
| 138 | + ∂R, ∂R_tmp = ∂R_tmp, empty(∂R) |
| 139 | + end |
| 140 | + return R |
| 141 | +end |
| 142 | + |
| 143 | +## Spatial point processes |
| 144 | + |
| 145 | +function initialize_cells( |
| 146 | + spp::AbstractSpatialPointProcess{T}, |
| 147 | + size::Integer |
| 148 | +)::Vector{SpatialCellGridPRS{T}} where {T} |
| 149 | + cells = Vector{SpatialCellGridPRS{T}}(undef, size) |
| 150 | + k = ceil(Int, inv(spp.r)) |
| 151 | + for i in eachindex(cells) |
| 152 | + c_y, c_x = divrem(i-1, k) |
| 153 | + c = spp.window.c + spp.r .* [c_x, c_y] |
| 154 | + win_i = spatial_window(c, @. min(spp.r, spp.window.w - c)) |
| 155 | + cells[i] = SpatialCellGridPRS(win_i, T[]) |
| 156 | + end |
| 157 | + return cells |
| 158 | +end |
| 159 | + |
| 160 | +function is_inner_interaction_possible( |
| 161 | + spp::AbstractSpatialPointProcess, |
| 162 | + cell_i::SpatialCellGridPRS, |
| 163 | + cell_j::SpatialCellGridPRS |
| 164 | +)::Bool |
| 165 | + return false |
| 166 | +end |
| 167 | + |
| 168 | +@doc raw""" |
| 169 | +Assuming ``C_i`` and ``C_j`` are neighboring cells, |
| 170 | +Given the configuration of ``C_i``, check whether an assignment of ``C_j`` can induce a bad event ``\{i, j\}``. |
| 171 | +""" |
| 172 | +function is_outer_interaction_possible( |
| 173 | + spp::AbstractSpatialPointProcess, |
| 174 | + cell_i::SpatialCellGridPRS, |
| 175 | + cell_j::SpatialCellGridPRS |
| 176 | +)::Bool |
| 177 | + isempty(cell_i) && return false |
| 178 | + |
| 179 | + win_i, win_j = window(cell_i), window(cell_j) |
| 180 | + # If i-j is a vertical or horizontal neighborhood |
| 181 | + any(win_i.c .== win_j.c) && return true |
| 182 | + |
| 183 | + # If i-j is a diagonal neighborhood |
| 184 | + c_ij = any(win_i.c .< win_j.c) ? win_j.c : win_i.c |
| 185 | + return any(Distances.euclidean(c_ij, x) < spp.r for x in cell_i) |
| 186 | +end |
0 commit comments