Skip to content

Commit cd1cbb2

Browse files
committed
Add script for Ising
1 parent c8da282 commit cd1cbb2

File tree

1 file changed

+181
-0
lines changed

1 file changed

+181
-0
lines changed

src/ising_prs.jl

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import Random
2+
getRNG(seed::Integer = -1) = seed >= 0 ? Random.MersenneTwister(seed) : Random.GLOBAL_RNG
3+
getRNG(seed::Union{Random.MersenneTwister,Random._GLOBAL_RNG}) = seed
4+
5+
import LightGraphs
6+
const LG = LightGraphs
7+
8+
import SimpleWeightedGraphs
9+
const SWG = SimpleWeightedGraphs
10+
11+
using GraphPlot, Colors
12+
13+
function ising_dependency_graph(
14+
dims::Tuple{T,T};
15+
periodic::Bool = false,
16+
rng = -1,
17+
)::SWG.SimpleWeightedGraph{Int64,Float64} where {T}
18+
rng = getRNG(rng)
19+
g = SWG.SimpleWeightedGraph(LG.grid(dims, periodic = periodic))
20+
for e in LG.edges(g)
21+
i, j = Tuple(e)
22+
@inbounds g.weights[i, j] = g.weights[j, i] = rand(rng)
23+
end
24+
return g
25+
end
26+
27+
"""
28+
Find bad
29+
U_ij > exp(- |J| (1 - sign(J) x_i x_j))
30+
log(U_ij) > - |J| (1 - sign(J) x_i x_j)
31+
sign(J) x_i x_j < 0 & log(U_ij) > - 2 |J|
32+
"""
33+
function ising_find_bad_states(
34+
g::SWG.SimpleWeightedGraph{T, U},
35+
states::Vector{T},
36+
J::U;
37+
rng = -1,
38+
)::Set{T} where {T, U}
39+
40+
@assert LG.nv(g) == length(states)
41+
42+
rng = getRNG(rng)
43+
44+
_2absJ, sign_J = -2.0 * abs(J), sign(J)
45+
46+
bad = Set{T}()
47+
for e LG.edges(g)
48+
i, j = Tuple(e)
49+
# Break constraint:
50+
# log(U_ij) > - |J| (1 - sign(J) x_i x_j)
51+
# <=> sign(J) x_i x_j < 0 & log(U_ij) > - 2 |J|
52+
if ((sign_J * states[i] * states[j]) < 0.0) & (log(e.weight) > _2absJ)
53+
union!(bad, i, j)
54+
end
55+
end
56+
return bad
57+
end
58+
59+
"""
60+
Find Res
61+
"""
62+
function ising_find_states_to_resample!(
63+
g::SWG.SimpleWeightedGraph{T, U},
64+
states::Vector{T},
65+
J::U;
66+
rng = -1
67+
)::Set{T} where {T, U}
68+
69+
rng = getRNG(rng)
70+
sign_J = sign(J)
71+
R = ising_find_bad_states(g, states, J, rng = rng)
72+
∂R, ∂R_tmp = copy(R), Set{T}()
73+
while !isempty(∂R)
74+
for i ∂R
75+
for j LG.neighbors(g, i)
76+
# Break constraint:
77+
# log(U_ij) > - |J| (1 - sign(J) x_i x_j)
78+
# <=> sign(J) x_i x_j < 0 & log(U_ij) > - 2 |J|
79+
if j R
80+
if (sign_J * states[i] * states[j]) < 0.0
81+
# U_ij can be increased
82+
@inbounds g.weights[i, j] = g.weights[j, i] = rand(rng)
83+
end
84+
else # if j ∉ R
85+
# x_j can be flipped to make sign(J) x_i x_j < 0
86+
union!(R, j)
87+
union!(∂R_tmp, j)
88+
# followed by an increase of U_ij
89+
@inbounds g.weights[i, j] = g.weights[j, i] = rand(rng)
90+
end
91+
end
92+
end
93+
∂R, ∂R_tmp = ∂R_tmp, Set{T}()
94+
end
95+
return R
96+
end
97+
98+
"""
99+
Resample
100+
"""
101+
function ising_sample_states!(
102+
states::Vector{T},
103+
res_ind::Set{T},
104+
probas::Vector{Float64};
105+
rng = -1,
106+
) where {T}
107+
n = length(states)
108+
@assert n == length(probas)
109+
rng = getRNG(rng)
110+
for i in res_ind
111+
states[i] = rand(rng) < probas[i] ? 1 : -1
112+
end
113+
end
114+
115+
sigmoid(x) = @. 1 / (1 + exp(-x))
116+
117+
function ising_prs(
118+
dims::Tuple{T, T},
119+
h::Vector{U},
120+
J::U;
121+
periodic::Bool=false,
122+
rng=-1,
123+
) where {T, U}
124+
rng = getRNG(rng)
125+
126+
g = ising_dependency_graph(dims, periodic=periodic, rng=rng)
127+
128+
n = LG.nv(g)
129+
states = Vector{Int64}(undef, n)
130+
probas = sigmoid.(2.0 .* h)
131+
res = Set{T}(1:LG.nv(g))
132+
133+
cnt = 0
134+
while !isempty(res)
135+
ising_sample_states!(states, res, probas, rng=rng)
136+
res = ising_find_states_to_resample!(g, states, J, rng = rng)
137+
cnt += 1
138+
end
139+
140+
return g, states, cnt
141+
end
142+
143+
function ising_prs(
144+
dims::Tuple{T, T},
145+
h::U,
146+
J::U;
147+
periodic::Bool = false,
148+
rng = -1,
149+
) where {T, U}
150+
h_vec = fill(h, prod(dims))
151+
return ising_prs(dims, h_vec, J, periodic=periodic, rng=rng)
152+
end
153+
154+
function plot_ising(g, dims, state)
155+
pos = collect(Iterators.product(1:dims[1], 1:dims[2]))[:]
156+
locs_x, locs_y = map(x->x[1], pos), map(x->x[2], pos)
157+
158+
col_nodes = ifelse.(state .== 1, colorant"gray", colorant"white")
159+
160+
p = gplot(g,
161+
locs_x, reverse(locs_y),
162+
nodefillc=col_nodes,
163+
# nodelabel=LG.vertices(g),
164+
# arrowlengthfrac=0.05
165+
# edgestrokec=col_edges
166+
)
167+
display(p)
168+
end
169+
170+
171+
dims = (14, 14) # if > (14, 14) the display becomes all black, don't know why !
172+
H, J = 0.0, -0.02 # Use Float
173+
174+
periodic = false
175+
seed = -1
176+
g, config, cnt = ising_prs(dims, H, J; periodic=periodic, rng=seed)
177+
178+
plot_ising(g, dims, config)
179+
180+
println("Number of resampling steps")
181+
println(cnt)

0 commit comments

Comments
 (0)