@@ -68,7 +68,7 @@ struct HeatBathUpdater {
6868 std::int64_t GenerateNewValue (SystemType &sa_system, const std::int64_t index,
6969 const double T, const double _progress) {
7070 if (sa_system.OnlyMultiLinearCoeff (index)) {
71- return ForBilinear (sa_system, index, T, _progress);
71+ return ForBilinear (sa_system, index, T, _progress);
7272 } else {
7373 return ForAll (sa_system, index, T, _progress);
7474 }
@@ -77,25 +77,36 @@ struct HeatBathUpdater {
7777 template <typename SystemType>
7878 std::int64_t ForAll (SystemType &sa_system, const std::int64_t index,
7979 const double T, const double _progress) {
80+
81+ const auto [max_weight_state_value, min_dE] = sa_system.GetMinEnergyDifference (index);
8082 const auto &var = sa_system.GetState ()[index];
8183 const double beta = 1.0 / T;
84+ double z = 0.0 ;
85+
86+ // Calculate the partition function
87+ for (std::int64_t i = 0 ; i < var.num_states ; ++i) {
88+ const double dE = sa_system.GetEnergyDifference (index, var.GetValueFromState (i)) - min_dE;
89+ z += std::exp (-beta * dE);
90+ }
91+
92+ // Select a state based on the partition function
8293 std::int64_t selected_state_number = -1 ;
83- double max_z = -std::numeric_limits<double >::infinity ();
94+ double cumulative_prob = 0.0 ;
95+ const double rand = dist (sa_system.random_number_engine ) * z;
8496
8597 for (std::int64_t i = 0 ; i < var.num_states ; ++i) {
86- const double g =
87- -std::log (-std::log (dist (sa_system.random_number_engine )));
88- const double z = -beta * sa_system.GetEnergyDifference (
89- index, var.GetValueFromState (i)) +
90- g;
91- if (z > max_z) {
92- max_z = z;
98+ const double dE = sa_system.GetEnergyDifference (index, var.GetValueFromState (i)) - min_dE;
99+ cumulative_prob += std::exp (-beta * dE);
100+ if (rand <= cumulative_prob) {
93101 selected_state_number = i;
102+ break ;
94103 }
95104 }
105+
96106 if (selected_state_number == -1 ) {
97107 throw std::runtime_error (" No state selected." );
98108 }
109+
99110 return var.GetValueFromState (selected_state_number);
100111 }
101112
@@ -137,66 +148,53 @@ struct SuwaTodoUpdater {
137148 const double T, const double _progress) {
138149 const auto &var = sa_system.GetState ()[index];
139150 const std::int64_t max_num_state = var.num_states ;
140-
141151 std::vector<double > weight_list (max_num_state, 0.0 );
142- std::vector<double > sum_weight_list (max_num_state + 1 , 0.0 );
143- std::vector<double > dE_list (max_num_state, 0.0 );
144-
145- const auto [max_weight_state_value, min_dE] =
146- sa_system.GetMinEnergyDifference (index);
147- const std::int64_t max_weight_state =
148- max_weight_state_value - var.lower_bound ;
152+ std::vector<double > sum_weight_list (max_num_state, 0.0 );
149153
150- for (std::int64_t state = 0 ; state < var.num_states ; ++state) {
151- const std::int64_t value = var.GetValueFromState (state);
152- dE_list[state] = sa_system.GetEnergyDifference (index, value) - min_dE;
153- }
154-
155- weight_list[0 ] = std::exp (-dE_list[max_weight_state] / T);
156- sum_weight_list[1 ] = weight_list[0 ];
154+ const auto [max_weight_state_value, min_dE] = sa_system.GetMinEnergyDifference (index);
155+ const auto max_weight_state = var.GetStateFromValue (max_weight_state_value);
157156
158- for (std::int64_t state = 1 ; state < var.num_states ; ++state) {
159- if (state == max_weight_state) {
160- weight_list[state] = std::exp (-dE_list[0 ] / T);
161- } else {
162- weight_list[state] = std::exp (-dE_list[state] / T);
163- }
164- sum_weight_list[state + 1 ] = sum_weight_list[state] + weight_list[state];
157+ for (std::int64_t i = 0 ; i < max_num_state; ++i) {
158+ const std::int64_t state = (i == 0 ) ? max_weight_state : ((i == max_weight_state) ? 0 : i);
159+ const double dE = sa_system.GetEnergyDifference (index, var.GetValueFromState (state)) - min_dE;
160+ weight_list[i] = std::exp (-dE / T);
161+ sum_weight_list[i] = (i == 0 ) ? weight_list[i] : sum_weight_list[i - 1 ] + weight_list[i];
165162 }
166163
167- sum_weight_list[0 ] = sum_weight_list[var.num_states ];
168- const std::int64_t current_state = var.value - var.lower_bound ;
169- std::int64_t now_state;
164+ std::int64_t current_state = var.GetStateFromValue (var.value );
170165 if (current_state == 0 ) {
171- now_state = max_weight_state;
166+ current_state = max_weight_state;
172167 } else if (current_state == max_weight_state) {
173- now_state = 0 ;
174- } else {
175- now_state = current_state;
168+ current_state = 0 ;
176169 }
177170
171+ const double w_0 = weight_list[0 ];
172+ const double w_c = weight_list[current_state];
173+ const double sum_w_c = sum_weight_list[current_state];
174+ const double rand = dist (sa_system.random_number_engine ) * w_c;
175+ std::int64_t selected_state = -1 ;
178176 double prob_sum = 0.0 ;
179- const double rand = dist (sa_system.random_number_engine ) * weight_list[now_state];
180-
181- for (std::int64_t j = 0 ; j < var.num_states ; ++j) {
182- const double d_ij = sum_weight_list[now_state + 1 ] - sum_weight_list[j] +
183- sum_weight_list[1 ];
184- prob_sum += std::max (0.0 , std::min ({d_ij, weight_list[now_state] + weight_list[j] - d_ij,
185- weight_list[now_state], weight_list[j]}));
186- if (rand < prob_sum) {
187- std::int64_t new_state;
188- if (j == max_weight_state) {
189- new_state = 0 ;
190- } else if (j == 0 ) {
191- new_state = max_weight_state;
177+
178+ for (std::int64_t j = 0 ; j < max_num_state; ++j) {
179+ const double d_ij = sum_w_c - sum_weight_list[(j - 1 + max_num_state) % max_num_state] + w_0;
180+ prob_sum += std::max (0.0 , std::min ({d_ij, w_c + (weight_list[j] - d_ij), w_c, weight_list[j]}));
181+ if (rand <= prob_sum) {
182+ if (j == 0 ) {
183+ selected_state = max_weight_state;
184+ } else if (j == max_weight_state) {
185+ selected_state = 0 ;
192186 } else {
193- new_state = j;
187+ selected_state = j;
194188 }
195- return var. GetValueFromState (new_state) ;
189+ break ;
196190 }
197191 }
198192
199- return var.GetValueFromState (var.num_states - 1 );
193+ if (selected_state == -1 ) {
194+ throw std::runtime_error (" No state selected." );
195+ }
196+
197+ return var.GetValueFromState (selected_state);
200198 }
201199
202200 std::uniform_real_distribution<double > dist{0.0 , 1.0 };
0 commit comments