Skip to content

Commit fce95a3

Browse files
committed
Package ising code
1 parent 4ad688f commit fce95a3

File tree

3 files changed

+202
-113
lines changed

3 files changed

+202
-113
lines changed

src/PartialRejectionSampling.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
module PartialRejectionSampling
22

3-
# Write your package code here.
3+
using LinearAlgebra
4+
const LA = LinearAlgebra
5+
6+
using Random, Distributions
7+
8+
using LightGraphs
9+
const LG = LightGraphs
10+
11+
using SimpleWeightedGraphs
12+
const SWG = SimpleWeightedGraphs
13+
14+
using Plots, GraphPlot, Colors
15+
16+
include("ising.jl")
17+
include("utils.jl")
418

519
end

src/ising.jl

Lines changed: 183 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,182 @@
1-
import Random
2-
getRNG(seed::Integer = -1) = seed >= 0 ? Random.MersenneTwister(seed) : Random.GLOBAL_RNG
3-
getRNG(seed::Union{Random.MersenneTwister,Random._GLOBAL_RNG}) = seed
1+
struct Ising
2+
"Grid dimension"
3+
dims::AbstractVector{Int}
4+
"Interaction graph"
5+
g::LG.SimpleGraph{Int}
6+
"Magnetization"
7+
h::Union{Real, AbstractVector{Real}}
8+
"Interaction"
9+
J::Real
10+
end
411

12+
function Ising(
13+
dims::AbstractVector{Int},
14+
periodic::Bool,
15+
h::Union{Real, AbstractVector{Real}},
16+
J::Real
17+
)
18+
if isa(h, AbstractVector)
19+
@assert length(h) != prod(dims) "length(h) != prod(dims)"
20+
end
21+
g = LG.grid(dims; periodic=periodic)
22+
return Ising(dims, g, h, J)
23+
end
524

6-
import LightGraphs
7-
const LG = LightGraphs
25+
function Ising(
26+
dims::AbstractVector{Int},
27+
periodic::Bool,
28+
J::Real
29+
)
30+
return Ising(dims, periodic, 0.0, J)
31+
end
832

9-
import SimpleWeightedGraphs
10-
const SWG = SimpleWeightedGraphs
33+
function energy(
34+
ising::Ising,
35+
state::AbstractVector{Int}
36+
)::Real
37+
@assert all(x -> x [-1, 1], state) "state values not all {-1, 1}"
1138

12-
using GraphPlot, Colors
39+
E = 0.0
1340

14-
function ising_dependency_graph(
15-
dims::Vector{T};
16-
periodic::Bool = false,
17-
rng = -1,
18-
)::SWG.SimpleWeightedGraph{Int64,Float64} where {T}
19-
rng = getRNG(rng)
20-
g = SWG.SimpleWeightedGraph(LG.grid(dims, periodic = periodic))
21-
for e in LG.edges(g)
41+
h = ising.h
42+
if isa(h, Real)
43+
E = - h * sum(state)
44+
else # isa(h, AbstractVector)
45+
E = - h' * state
46+
end
47+
48+
J = ising.J
49+
for e in LG.edges(ising.g)
2250
i, j = Tuple(e)
23-
@inbounds g.weights[i, j] = g.weights[j, i] = rand(rng)
51+
xᵢ, xⱼ = state[i], state[j]
52+
E += J * xᵢ * xⱼ
2453
end
25-
return g
54+
return E
2655
end
2756

28-
"""
29-
Find bad
30-
U_ij > exp(- |J| (1 - sign(J) x_i x_j))
31-
log(U_ij) > - |J| (1 - sign(J) x_i x_j)
32-
sign(J) x_i x_j < 0 & log(U_ij) > - 2 |J|
33-
"""
34-
function ising_find_bad_states(
35-
g::SWG.SimpleWeightedGraph{T, U},
36-
states::Vector{T},
37-
J::U;
38-
rng = -1,
39-
)::Set{T} where {T, U}
57+
function plot(
58+
ising::Ising,
59+
state::AbstractVector{Int}
60+
)
61+
62+
pos = collect(Iterators.product(1:ising.dims[1], 1:ising.dims[2]))[:]
63+
locs_x, locs_y = map(x->x[1], pos), map(x->x[2], pos)
64+
65+
col_nodes = ifelse.(
66+
state .== 1,
67+
Colors.colorant"gray",
68+
Colors.colorant"white")
69+
70+
p = GraphPlot.gplot(
71+
ising.g,
72+
locs_x,
73+
reverse(locs_y),
74+
nodefillc=col_nodes)
75+
# nodelabel=LG.vertices(g),
76+
# arrowlengthfrac=0.05
77+
# edgestrokec=col_edges
78+
79+
display(p)
80+
end
81+
82+
function sample_state!(
83+
state::AbstractVector{T},
84+
ising::Ising,
85+
i::T;
86+
rng=-1
87+
) where {T<:Int}
88+
rng = getRNG(rng)
89+
hᵢ = isa(ising.h, Real) ? ising.h : ising.h[i]
90+
state[i] = rand(rng) < sigmoid(hᵢ) ? 1 : -1
91+
end
92+
93+
function sample_states!(
94+
state::AbstractVector{T},
95+
ising::Ising,
96+
indices::Set{T};
97+
rng=-1
98+
) where {T<:Int}
99+
for i in indices
100+
sample_state!(state, ising, i; rng=rng)
101+
end
102+
end
103+
104+
function sample_conditional!(
105+
state::AbstractVector{T},
106+
ising::Ising,
107+
i::T;
108+
rng=-1
109+
) where {T<:Int}
110+
rng = getRNG(rng)
111+
hᵢ = isa(ising.h, Real) ? ising.h : ising.h[i]
112+
proba = sigmoid(hᵢ + ising.J * sum(state[j] for j in LG.neighbors(ising.g, i)))
113+
state[i] = rand(rng) < proba ? 1 : -1
114+
end
40115

41-
@assert LG.nv(g) == length(states)
116+
## Partial Rejection Sampling sampler (mix of Guo Jerrum Liu and Moka Kroese)
42117

118+
function weighted_graph(
119+
ising::Ising;
120+
rng=-1
121+
)::SWG.SimpleWeightedGraph{Int64,Float64}
43122
rng = getRNG(rng)
123+
wg = SWG.SimpleWeightedGraph(ising.g)
124+
for e in LG.edges(wg)
125+
i, j = Tuple(e)
126+
@inbounds wg.weights[i, j] = wg.weights[j, i] = rand(rng)
127+
end
128+
return wg
129+
end
130+
131+
function bad_states(
132+
ising::Ising,
133+
state::AbstractVector{T},
134+
wg::SWG.SimpleWeightedGraph{T, U}
135+
)::Set{T} where {T, U}
44136

45-
_2absJ, sign_J = -2.0 * abs(J), sign(J)
137+
_2absJ, sign_J = -2.0 * abs(ising.J), sign(ising.J)
46138

47139
bad = Set{T}()
48-
for e LG.edges(g)
140+
for e LG.edges(wg)
49141
i, j = Tuple(e)
50142
# Break constraint:
51143
# log(U_ij) > - |J| (1 - sign(J) x_i x_j)
52144
# <=> sign(J) x_i x_j < 0 & log(U_ij) > - 2 |J|
53-
if ((sign_J * states[i] * states[j]) < 0.0) & (log(e.weight) > _2absJ)
145+
if ((sign_J * state[i] * state[j]) < 0.0) & (log(e.weight) > _2absJ)
54146
union!(bad, i, j)
55147
end
56148
end
57149
return bad
58150
end
59151

60-
"""
61-
Find Res
62-
"""
63-
function ising_find_states_to_resample!(
64-
g::SWG.SimpleWeightedGraph{T, U},
65-
states::Vector{T},
66-
J::U;
67-
rng = -1
152+
function resampling_states!(
153+
wg::SWG.SimpleWeightedGraph{T, U},
154+
ising::Ising,
155+
state::AbstractVector{T};
156+
rng=-1
68157
)::Set{T} where {T, U}
69158

70159
rng = getRNG(rng)
71-
sign_J = sign(J)
72-
R = ising_find_bad_states(g, states, J, rng = rng)
160+
sign_J = sign(ising.J)
161+
R = bad_states(ising, state, wg)
73162
∂R, ∂R_tmp = copy(R), Set{T}()
74163
while !isempty(∂R)
75164
for i ∂R
76-
for j LG.neighbors(g, i)
165+
for j LG.neighbors(wg, i)
77166
# Break constraint:
78167
# log(U_ij) > - |J| (1 - sign(J) x_i x_j)
79168
# <=> sign(J) x_i x_j < 0 & log(U_ij) > - 2 |J|
80169
if j R
81-
if (sign_J * states[i] * states[j]) < 0.0
170+
if (sign_J * state[i] * state[j]) < 0.0
82171
# U_ij can be increased
83-
@inbounds g.weights[i, j] = g.weights[j, i] = rand(rng)
172+
@inbounds wg.weights[i, j] = wg.weights[j, i] = rand(rng)
84173
end
85174
else # if j ∉ R
86175
# x_j can be flipped to make sign(J) x_i x_j < 0
87176
union!(R, j)
88177
union!(∂R_tmp, j)
89178
# followed by an increase of U_ij
90-
@inbounds g.weights[i, j] = g.weights[j, i] = rand(rng)
179+
@inbounds wg.weights[i, j] = wg.weights[j, i] = rand(rng)
91180
end
92181
end
93182
end
@@ -96,85 +185,67 @@ function ising_find_states_to_resample!(
96185
return R
97186
end
98187

99-
"""
100-
Resample
101-
"""
102-
function ising_sample_states!(
103-
states::Vector{T},
104-
res_ind::Set{T},
105-
probas::Vector{Float64};
106-
rng = -1,
107-
) where {T}
108-
n = length(states)
109-
@assert n == length(probas)
110-
rng = getRNG(rng)
111-
for i in res_ind
112-
states[i] = rand(rng) < probas[i] ? 1 : -1
113-
end
114-
end
115-
116-
sigmoid(x) = @. 1 / (1 + exp(-x))
117-
118-
function ising_prs(
119-
dims::Vector{T},
120-
h::Vector{U},
121-
J::U;
122-
periodic::Bool=false,
188+
function prs(
189+
ising::Ising;
123190
rng=-1,
124-
) where {T, U}
191+
)::Tuple{AbstractVector{Int}, Int}
125192
rng = getRNG(rng)
126193

127-
g = ising_dependency_graph(dims, periodic=periodic, rng=rng)
128-
129-
n = LG.nv(g)
130-
states = Vector{Int64}(undef, n)
131-
probas = sigmoid.(2.0 .* h)
132-
res = Set{T}(1:LG.nv(g))
194+
g = weighted_graph(ising; rng=rng)
195+
state = Vector{Int}(undef, LG.nv(g))
196+
res = Set{Int}(1:LG.nv(g))
133197

134198
cnt = 0
135199
while !isempty(res)
136-
ising_sample_states!(states, res, probas, rng=rng)
137-
res = ising_find_states_to_resample!(g, states, J, rng = rng)
200+
sample_states!(state, ising, res; rng=rng)
201+
res = resampling_states!(g, ising, state; rng=rng)
138202
cnt += 1
139203
end
140204

141-
return g, states, cnt
205+
return state, cnt
142206
end
143207

144-
function ising_prs(
145-
dims::Vector{T},
146-
h::U,
147-
J::U;
148-
periodic::Bool = false,
149-
rng = -1,
150-
) where {T, U}
151-
h_vec = fill(h, prod(dims))
152-
return ising_prs(dims, h_vec, J, periodic=periodic, rng=rng)
153-
end
208+
## Perfect Gibbs sampler (Feng, Guo, Yin) https://arxiv.org/pdf/1907.06033.pdf
154209

155-
function plot_ising(g, dims, state)
156-
pos = collect(Iterators.product(1:dims[1], 1:dims[2]))[:]
157-
locs_x, locs_y = map(x->x[1], pos), map(x->x[2], pos)
210+
function bayes_filter(
211+
ising::Ising,
212+
state::AbstractVector{T},
213+
i::T,
214+
R::Set{T};
215+
rng=-1
216+
)::Bool where {T<:Int}
158217

159-
col_nodes = ifelse.(state .== 1, colorant"gray", colorant"white")
218+
∂i_R̄ = setdiff(LG.neighbors(ising.g, i), R)
219+
isempty(∂i_R̄) && return true
160220

161-
p = gplot(g,
162-
locs_x, reverse(locs_y),
163-
nodefillc=col_nodes,
164-
# nodelabel=LG.vertices(g),
165-
# arrowlengthfrac=0.05
166-
# edgestrokec=col_edges
167-
)
168-
display(p)
221+
rng = getRNG(rng)
222+
Xi_J = state[i] * ising.J
223+
acc_r = Xi_J * (-sign(Xi_J) * length(∂i_R̄) - sum(state[j] for j in ∂i_R̄))
224+
return log(rand(rng)) < acc_r
169225
end
170226

171-
dims = [14, 14] # if > (14, 14) the display becomes all black, don't know why !
172-
H, J = 0.0, -0.02 # Use Float
227+
function gibbs_perfect(
228+
ising::Ising;
229+
rng=-1
230+
)::Tuple{AbstractVector{Int}, Int}
231+
rng = getRNG(rng)
173232

174-
periodic = false
175-
seed = -1
176-
g, config, cnt = ising_prs(dims, H, J; periodic=periodic, rng=seed)
233+
n = LG.nv(ising.g)
234+
R = Set{Int}(1:n)
235+
state = Vector{Int}(undef, n)
236+
sample_states!(state, ising, R; rng=rng)
237+
238+
cnt = 1
239+
while !isempty(R)
240+
i = rand(rng, R)
241+
if bayes_filter(ising, state, i, R; rng=-1)
242+
sample_conditional!(state, ising, i; rng=rng)
243+
delete!(R, i)
244+
else
245+
union!(R, LG.neighbors(ising.g, i))
246+
end
247+
cnt += 1
248+
end
177249

178-
plot_ising(g, dims, config)
179-
println("Number of resampling steps")
180-
print(cnt)
250+
return state, cnt
251+
end

src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
getRNG(seed::Integer = -1) = seed >= 0 ? Random.MersenneTwister(seed) : Random.GLOBAL_RNG
2+
getRNG(seed::Union{Random.MersenneTwister,Random._GLOBAL_RNG}) = seed
3+
4+
sigmoid(x) = @. 1 / (1 + exp(-x))

0 commit comments

Comments
 (0)