Skip to content

File tree

1 file changed

+52
-54
lines changed

1 file changed

+52
-54
lines changed

include/openjij/updater/single_integer_move.hpp

Lines changed: 52 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ struct HeatBathUpdater {
6868
std::int64_t GenerateNewValue(SystemType &sa_system, const std::int64_t index,
6969
const double T, const double _progress) {
7070
if (sa_system.OnlyMultiLinearCoeff(index)) {
71-
return ForBilinear(sa_system, index, T, _progress);
71+
return ForBilinear(sa_system, index, T, _progress);
7272
} else {
7373
return ForAll(sa_system, index, T, _progress);
7474
}
@@ -77,25 +77,36 @@ struct HeatBathUpdater {
7777
template <typename SystemType>
7878
std::int64_t ForAll(SystemType &sa_system, const std::int64_t index,
7979
const double T, const double _progress) {
80+
81+
const auto [max_weight_state_value, min_dE] = sa_system.GetMinEnergyDifference(index);
8082
const auto &var = sa_system.GetState()[index];
8183
const double beta = 1.0 / T;
84+
double z = 0.0;
85+
86+
// Calculate the partition function
87+
for (std::int64_t i = 0; i < var.num_states; ++i) {
88+
const double dE = sa_system.GetEnergyDifference(index, var.GetValueFromState(i)) - min_dE;
89+
z += std::exp(-beta * dE);
90+
}
91+
92+
// Select a state based on the partition function
8293
std::int64_t selected_state_number = -1;
83-
double max_z = -std::numeric_limits<double>::infinity();
94+
double cumulative_prob = 0.0;
95+
const double rand = dist(sa_system.random_number_engine) * z;
8496

8597
for (std::int64_t i = 0; i < var.num_states; ++i) {
86-
const double g =
87-
-std::log(-std::log(dist(sa_system.random_number_engine)));
88-
const double z = -beta * sa_system.GetEnergyDifference(
89-
index, var.GetValueFromState(i)) +
90-
g;
91-
if (z > max_z) {
92-
max_z = z;
98+
const double dE = sa_system.GetEnergyDifference(index, var.GetValueFromState(i)) - min_dE;
99+
cumulative_prob += std::exp(-beta * dE);
100+
if (rand <= cumulative_prob) {
93101
selected_state_number = i;
102+
break;
94103
}
95104
}
105+
96106
if (selected_state_number == -1) {
97107
throw std::runtime_error("No state selected.");
98108
}
109+
99110
return var.GetValueFromState(selected_state_number);
100111
}
101112

@@ -137,66 +148,53 @@ struct SuwaTodoUpdater {
137148
const double T, const double _progress) {
138149
const auto &var = sa_system.GetState()[index];
139150
const std::int64_t max_num_state = var.num_states;
140-
141151
std::vector<double> weight_list(max_num_state, 0.0);
142-
std::vector<double> sum_weight_list(max_num_state + 1, 0.0);
143-
std::vector<double> dE_list(max_num_state, 0.0);
144-
145-
const auto [max_weight_state_value, min_dE] =
146-
sa_system.GetMinEnergyDifference(index);
147-
const std::int64_t max_weight_state =
148-
max_weight_state_value - var.lower_bound;
152+
std::vector<double> sum_weight_list(max_num_state, 0.0);
149153

150-
for (std::int64_t state = 0; state < var.num_states; ++state) {
151-
const std::int64_t value = var.GetValueFromState(state);
152-
dE_list[state] = sa_system.GetEnergyDifference(index, value) - min_dE;
153-
}
154-
155-
weight_list[0] = std::exp(-dE_list[max_weight_state] / T);
156-
sum_weight_list[1] = weight_list[0];
154+
const auto [max_weight_state_value, min_dE] = sa_system.GetMinEnergyDifference(index);
155+
const auto max_weight_state = var.GetStateFromValue(max_weight_state_value);
157156

158-
for (std::int64_t state = 1; state < var.num_states; ++state) {
159-
if (state == max_weight_state) {
160-
weight_list[state] = std::exp(-dE_list[0] / T);
161-
} else {
162-
weight_list[state] = std::exp(-dE_list[state] / T);
163-
}
164-
sum_weight_list[state + 1] = sum_weight_list[state] + weight_list[state];
157+
for (std::int64_t i = 0; i < max_num_state; ++i) {
158+
const std::int64_t state = (i == 0) ? max_weight_state : ((i == max_weight_state) ? 0 : i);
159+
const double dE = sa_system.GetEnergyDifference(index, var.GetValueFromState(state)) - min_dE;
160+
weight_list[i] = std::exp(-dE / T);
161+
sum_weight_list[i] = (i == 0) ? weight_list[i] : sum_weight_list[i - 1] + weight_list[i];
165162
}
166163

167-
sum_weight_list[0] = sum_weight_list[var.num_states];
168-
const std::int64_t current_state = var.value - var.lower_bound;
169-
std::int64_t now_state;
164+
std::int64_t current_state = var.GetStateFromValue(var.value);
170165
if (current_state == 0) {
171-
now_state = max_weight_state;
166+
current_state = max_weight_state;
172167
} else if (current_state == max_weight_state) {
173-
now_state = 0;
174-
} else {
175-
now_state = current_state;
168+
current_state = 0;
176169
}
177170

171+
const double w_0 = weight_list[0];
172+
const double w_c = weight_list[current_state];
173+
const double sum_w_c = sum_weight_list[current_state];
174+
const double rand = dist(sa_system.random_number_engine) * w_c;
175+
std::int64_t selected_state = -1;
178176
double prob_sum = 0.0;
179-
const double rand = dist(sa_system.random_number_engine) * weight_list[now_state];
180-
181-
for (std::int64_t j = 0; j < var.num_states; ++j) {
182-
const double d_ij = sum_weight_list[now_state + 1] - sum_weight_list[j] +
183-
sum_weight_list[1];
184-
prob_sum += std::max(0.0, std::min({d_ij, weight_list[now_state] + weight_list[j] - d_ij,
185-
weight_list[now_state], weight_list[j]}));
186-
if (rand < prob_sum) {
187-
std::int64_t new_state;
188-
if (j == max_weight_state) {
189-
new_state = 0;
190-
} else if (j == 0) {
191-
new_state = max_weight_state;
177+
178+
for (std::int64_t j = 0; j < max_num_state; ++j) {
179+
const double d_ij = sum_w_c - sum_weight_list[(j - 1 + max_num_state) % max_num_state] + w_0;
180+
prob_sum += std::max(0.0, std::min({d_ij, w_c + (weight_list[j] - d_ij), w_c, weight_list[j]}));
181+
if (rand <= prob_sum) {
182+
if (j == 0) {
183+
selected_state = max_weight_state;
184+
} else if (j == max_weight_state) {
185+
selected_state = 0;
192186
} else {
193-
new_state = j;
187+
selected_state = j;
194188
}
195-
return var.GetValueFromState(new_state);
189+
break;
196190
}
197191
}
198192

199-
return var.GetValueFromState(var.num_states - 1);
193+
if (selected_state == -1) {
194+
throw std::runtime_error("No state selected.");
195+
}
196+
197+
return var.GetValueFromState(selected_state);
200198
}
201199

202200
std::uniform_real_distribution<double> dist{0.0, 1.0};

0 commit comments

Comments
 (0)