Skip to content

Commit 50d30c9

Browse files
committed
Package code from notebook
1 parent 30650fd commit 50d30c9

17 files changed

+1198
-316
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.1.0"
55

66
[deps]
77
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
8+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
89
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
910
GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
1011
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"

src/PartialRejectionSampling.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,38 @@ const SWG = SimpleWeightedGraphs
1313

1414
using Plots, GraphPlot, Colors
1515

16-
include("ising.jl")
16+
using Distances
17+
18+
abstract type AbstractPointProcess{T} end
19+
Base.eltype(pp::AbstractPointProcess{T}) where {T} = T
20+
abstract type AbstractSpatialPointProcess{T<:Vector} <: AbstractPointProcess{T} end
21+
abstract type AbstractGraphPointProcess{T} <: AbstractPointProcess{T} end
22+
23+
function generate_sample end
24+
1725
include("utils.jl")
26+
include("window.jl")
27+
28+
# sampling
29+
include("dominated_cftp.jl")
30+
include("grid_prs.jl")
31+
32+
# Spatial point processes
33+
include("strauss.jl")
34+
include("hard_core_spatial.jl")
35+
36+
include("hard_core_graph.jl")
37+
include("ising.jl")
38+
39+
# Graph point processes
40+
include("rooted_spanning_forest.jl")
41+
include("sink_free_graph.jl")
42+
43+
# Misc
44+
include("pattern_free_string.jl")
45+
46+
# Display
47+
include("display.jl")
48+
1849

1950
end

src/display.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
function plot(
2+
pp::AbstractSpatialPointProcess,
3+
points;
4+
title=""
5+
)
6+
p = Plots.plot([0], [0],
7+
label="", legend=false,
8+
color="white",
9+
linewidth=0.0,
10+
aspect_ratio=:equal,
11+
grid=:none,
12+
title=title)
13+
14+
θ = collect(range(0, 2π, length=15))
15+
rad = pp.r / 2 # radius = interaction range / 2
16+
circ_x, circ_y = rad .* cos.(θ), rad .* sin.(θ)
17+
18+
for x in points
19+
Plots.plot!(x[1] .+ circ_x,
20+
x[2] .+ circ_y,
21+
color="black",
22+
linewidth=0.2)
23+
end
24+
25+
win = pp.window
26+
Plots.xlims!(win.c[1], win.c[1] + win.w[1])
27+
Plots.ylims!(win.c[2], win.c[2] + (win.w isa Number ? win.w[1] : win.w[2]))
28+
29+
return p
30+
end
31+
32+
function plot(
33+
ising::Ising,
34+
state
35+
)
36+
37+
pos = collect(Iterators.product(1:ising.dims[1], 1:ising.dims[2]))[:]
38+
locs_x, locs_y = map(x->x[1], pos), map(x->x[2], pos)
39+
40+
col_nodes = ifelse.(
41+
state .== 1,
42+
Colors.colorant"gray",
43+
Colors.colorant"white")
44+
45+
p = GraphPlot.gplot(
46+
ising.g,
47+
locs_x,
48+
reverse(locs_y),
49+
nodefillc=col_nodes)
50+
# nodelabel=LG.vertices(g),
51+
# arrowlengthfrac=0.05
52+
# edgestrokec=col_edges
53+
54+
return p
55+
end

src/dominated_cftp.jl

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
## Dominated Coupling From The Past (dCFTP)
2+
# - [Kendall \& Moller's orginal formulation of dCFTP](https://www.researchgate.net/publication/ 2821877_Perfect_Metropolis-Hastings_simulation_of_locally_stable_point_processes)
3+
# - Huber, Perfect Simulation
4+
# - [Kendall's notes on perfect simulation](https://warwick.ac.uk/fac/sci/statistics/staff/ academic-research/kendall/personal/ppt/428.pdf)
5+
#
6+
# Requirement: the target spatial point process must have the followingmethods
7+
# - function papangelou_conditional_intensity end
8+
# - function upper_bound_papangelou_conditional_intensity end
9+
# - function window end
10+
11+
function generate_sample_dcftp(
12+
pp::AbstractSpatialPointProcess{T};
13+
n₀::Int64=1,
14+
win::Union{Nothing,AbstractWindow}=nothing,
15+
rng=-1
16+
)::Vector{T} where {T}
17+
18+
@assert n₀ >= 1
19+
rng = getRNG(rng)
20+
21+
win_ = win === nothing ? window(pp) : win
22+
β = upper_bound_papangelou_conditional_intensity(pp)
23+
birth_rate = β * volume(win_)
24+
25+
# Dominating process
26+
k = rand(rng, Distributions.Poisson(birth_rate))
27+
D = Set{T}(rand(win_; rng=rng) for _ in 1:k)
28+
29+
M = Float64[] # Marking process
30+
R = T[] # Recording process
31+
32+
steps = -1:-1:-n₀
33+
while true
34+
backward_update!(D, M, R, steps, birth_rate, win_; rng=rng)
35+
coupling, L = forward_coupling(D, M, R, pp, β)
36+
coupling && return collect(L)
37+
steps = (steps.stop-1):-1:(2*steps.stop)
38+
end
39+
end
40+
41+
function backward_update!(
42+
D::Set{T}, # Dominating process
43+
M::Vector{Float64}, # Marking process
44+
R::Vector{T}, # Recording process
45+
steps::StepRange, # Number of backward steps
46+
birth_rate::Real,
47+
window::AbstractWindow;
48+
rng=-1
49+
) where {T}
50+
rng = getRNG(rng)
51+
for _ in steps
52+
card_D = length(D)
53+
if rand(rng) < card_D / (birth_rate + card_D)
54+
# forward death (delete) ≡ backward birth (pushfirst)
55+
x = rand(rng, D)
56+
delete!(D, x)
57+
pushfirst!(R, x)
58+
pushfirst!(M, rand(rng))
59+
else
60+
# forward birth (push) ≡ backward death (pushfirst)
61+
x = rand(window; rng)
62+
push!(D, x)
63+
pushfirst!(R, x)
64+
pushfirst!(M, 0.0)
65+
end
66+
end
67+
end
68+
69+
function forward_coupling(
70+
D::Set{T}, # Dominating process
71+
M::Vector{Float64}, # Marking process
72+
R::Vector{T}, # Recording process
73+
pp::AbstractSpatialPointProcess{T},
74+
β::Real # Upper bound on papangelou conditional intensity
75+
) where {T}
76+
# L ⊆ X ⊆ U ⊆ D, where X is the target process
77+
L, U = empty(D), copy(D)
78+
for (m, x) in zip(M, R)
79+
if m > 0 # if birth occured in D
80+
if isrepulsive(pp)
81+
if m < papangelou_conditional_intensity(pp, x, U) / β
82+
push!(L, x)
83+
push!(U, x)
84+
elseif m < papangelou_conditional_intensity(pp, x, L) / β
85+
push!(U, x)
86+
end
87+
elseif isattractive(pp)
88+
if m < papangelou_conditional_intensity(pp, x, L) / β
89+
push!(L, x)
90+
push!(U, x)
91+
elseif m < papangelou_conditional_intensity(pp, x, U) / β
92+
push!(U, x)
93+
end
94+
else
95+
error("The current implementation of Dominated Coupling From The Past "
96+
* "requires the target point process to be attractive or repulsive.")
97+
end
98+
else # if death occured in D
99+
delete!(L, x)
100+
delete!(U, x)
101+
end
102+
end
103+
# Check coalescence L = U (Since L ⊆ U: L = U ⟺ |L| = |U|)
104+
return length(L) == length(U), L
105+
end

src/grid_prs.jl

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
"""
2+
Implementation of Grid Partial Rejection Sampling of [Moka, Sarat B. and Kroese, Dirk P. (2020)](https://espace.library.uq.edu.au/view/UQ:d924abb)
3+
"""
4+
5+
abstract type AbstractCellGridPRS end
6+
7+
window(cell::AbstractCellGridPRS) = cell.window
8+
dimension(cell::AbstractCellGridPRS) = dimension(window(cell))
9+
10+
mutable struct GraphCellGridPRS{T} <: AbstractCellGridPRS
11+
window::GraphNode{T}
12+
value::T
13+
end
14+
15+
mutable struct SpatialCellGridPRS{T<:Vector{Float64}} <: AbstractCellGridPRS
16+
window::Union{RectangleWindow,SquareWindow}
17+
value::Vector{T}
18+
end
19+
20+
Base.isempty(cell::SpatialCellGridPRS) = isempty(cell.value)
21+
Base.iterate(cell::AbstractCellGridPRS, state=1) = iterate(cell.value, state)
22+
23+
function generate_sample_grid_prs(
24+
pp::AbstractPointProcess{T};
25+
rng=-1
26+
)::Vector{T} where {T}
27+
rng = getRNG(rng)
28+
g = weighted_interaction_graph(pp; rng=rng)
29+
30+
cells = initialize_cells(pp, LG.nv(g))
31+
resample_indices = Set(1:length(cells))
32+
33+
while !isempty(resample_indices)
34+
generate_sample!(cells, resample_indices, pp; rng=rng)
35+
resample_indices = find_cells_to_resample_indices!(g, cells, pp; rng=rng)
36+
end
37+
return vcat(getfield.(cells, :value)...)
38+
end
39+
40+
@doc raw"""
41+
The dependency graph between the cells is a weighted king graph
42+
where each edge ``\{i,j\}`` is associated to an event and carries a uniform mark ``U_{ij}``.
43+
Note https://en.wikipedia.org/wiki/King%27s_graph
44+
"""
45+
function weighted_interaction_graph(
46+
pp::AbstractSpatialPointProcess;
47+
rng=-1
48+
)::SWG.SimpleWeightedGraph
49+
rng = getRNG(rng)
50+
g = SWG.SimpleWeightedGraph(king_graph(ceil(Int, inv(pp.r))))
51+
for e in LG.edges(g)
52+
i, j = Tuple(e)
53+
@inbounds g.weights[i, j] = g.weights[j, i] = rand(rng, weighttype(g))
54+
end
55+
return g
56+
end
57+
58+
function generate_sample!(
59+
cell::AbstractCellGridPRS,
60+
pp::AbstractPointProcess;
61+
rng=-1
62+
)
63+
rng = getRNG(rng)
64+
cell.value = generate_sample(pp; win=cell.window, rng=rng)
65+
end
66+
67+
function generate_sample!(
68+
cells::Vector{T},
69+
indices,
70+
pp::AbstractPointProcess;
71+
rng=-1
72+
) where {T<:AbstractCellGridPRS}
73+
rng = getRNG(rng)
74+
for i in indices
75+
generate_sample!(cells[i], pp; rng=rng)
76+
end
77+
end
78+
79+
@doc raw"""
80+
Identify bad events and return the corresponding cells' index.
81+
An event ``\{i,j\}`` is said to be \"bad\"
82+
83+
```math
84+
\left\{U_{ij} > \exp \left[ -\sum_{x \in C_i} \sum_{y \in C_j} V(x,y) \right] \right\}
85+
```
86+
87+
where ``U_{ij}`` is stored as the weight of edge ``\{i,j\}`` in the dependency graph ``g``.
88+
Note: when a bad event occurs, the corresponding ``U_{ij}`` is resampled hence the "!"
89+
"""
90+
function find_bad_cells_indices!(
91+
g::SWG.SimpleWeightedGraph{T,U},
92+
cells::Vector{V},
93+
pp::AbstractPointProcess;
94+
rng=-1
95+
)::Set{T} where {T,U,V<:AbstractCellGridPRS}
96+
rng = getRNG(rng)
97+
bad = Set{T}()
98+
for e LG.edges(g)
99+
i, j = Tuple(e)
100+
if e.weight > gibbs_interaction(pp, cells[i], cells[j])
101+
union!(bad, i, j)
102+
# resample U_ij associated to the bad event
103+
@inbounds g.weights[i, j] = g.weights[j, i] = rand(rng, U)
104+
end
105+
end
106+
return bad
107+
end
108+
109+
"""
110+
Identify which events need to be resampled and return the corresponding cells' index.
111+
This is the core of the Partial Rejection Sampling algorithm.
112+
Note: when an event needs to be resampled, the corresponding mark ``U_{ij}`` is resampled hence the "!"
113+
"""
114+
function find_cells_to_resample_indices!(
115+
g::SWG.SimpleWeightedGraph{T,U},
116+
cells::Vector{V},
117+
pp::AbstractPointProcess;
118+
rng=-1
119+
)::Set{T} where {T,U,V<:AbstractCellGridPRS}
120+
rng = getRNG(rng)
121+
R = find_bad_cells_indices!(g, cells, pp; rng=rng)
122+
∂R, ∂R_tmp = copy(R), empty(R)
123+
while !isempty(∂R)
124+
for i ∂R
125+
isempty(cells[i]) && continue
126+
for j LG.neighbors(g, i)
127+
if j R
128+
if is_inner_interaction_possible(pp, cells[i], cells[j])
129+
@inbounds g.weights[i, j] = g.weights[j, i] = rand(rng, U)
130+
end
131+
elseif is_outer_interaction_possible(pp, cells[i], cells[j])
132+
union!(R, j)
133+
union!(∂R_tmp, j)
134+
@inbounds g.weights[i, j] = g.weights[j, i] = rand(rng, U)
135+
end
136+
end
137+
end
138+
∂R, ∂R_tmp = ∂R_tmp, empty(∂R)
139+
end
140+
return R
141+
end
142+
143+
## Spatial point processes
144+
145+
function initialize_cells(
146+
spp::AbstractSpatialPointProcess{T},
147+
size::Integer
148+
)::Vector{SpatialCellGridPRS{T}} where {T}
149+
cells = Vector{SpatialCellGridPRS{T}}(undef, size)
150+
k = ceil(Int, inv(spp.r))
151+
for i in eachindex(cells)
152+
c_y, c_x = divrem(i-1, k)
153+
c = spp.window.c + spp.r .* [c_x, c_y]
154+
win_i = spatial_window(c, @. min(spp.r, spp.window.w - c))
155+
cells[i] = SpatialCellGridPRS(win_i, T[])
156+
end
157+
return cells
158+
end
159+
160+
function is_inner_interaction_possible(
161+
spp::AbstractSpatialPointProcess,
162+
cell_i::SpatialCellGridPRS,
163+
cell_j::SpatialCellGridPRS
164+
)::Bool
165+
return false
166+
end
167+
168+
@doc raw"""
169+
Assuming ``C_i`` and ``C_j`` are neighboring cells,
170+
Given the configuration of ``C_i``, check whether an assignment of ``C_j`` can induce a bad event ``\{i, j\}``.
171+
"""
172+
function is_outer_interaction_possible(
173+
spp::AbstractSpatialPointProcess,
174+
cell_i::SpatialCellGridPRS,
175+
cell_j::SpatialCellGridPRS
176+
)::Bool
177+
isempty(cell_i) && return false
178+
179+
win_i, win_j = window(cell_i), window(cell_j)
180+
# If i-j is a vertical or horizontal neighborhood
181+
any(win_i.c .== win_j.c) && return true
182+
183+
# If i-j is a diagonal neighborhood
184+
c_ij = any(win_i.c .< win_j.c) ? win_j.c : win_i.c
185+
return any(Distances.euclidean(c_ij, x) < spp.r for x in cell_i)
186+
end

0 commit comments

Comments
 (0)