@@ -77,36 +77,24 @@ struct HeatBathUpdater {
7777 template <typename SystemType>
7878 std::int64_t ForAll (SystemType &sa_system, const std::int64_t index,
7979 const double T, const double _progress) {
80-
81- const auto [max_weight_state_value, min_dE] = sa_system.GetMinEnergyDifference (index);
8280 const auto &var = sa_system.GetState ()[index];
8381 const double beta = 1.0 / T;
84- double z = 0.0 ;
85-
86- // Calculate the partition function
87- for (std::int64_t i = 0 ; i < var.num_states ; ++i) {
88- const double dE = sa_system.GetEnergyDifference (index, var.GetValueFromState (i)) - min_dE;
89- z += std::exp (-beta * dE);
90- }
91-
92- // Select a state based on the partition function
9382 std::int64_t selected_state_number = -1 ;
94- double cumulative_prob = 0.0 ;
95- const double rand = dist (sa_system.random_number_engine ) * z;
83+ double max_z = -std::numeric_limits<double >::infinity ();
9684
9785 for (std::int64_t i = 0 ; i < var.num_states ; ++i) {
98- const double dE = sa_system.GetEnergyDifference (index, var.GetValueFromState (i)) - min_dE;
99- cumulative_prob += std::exp (-beta * dE);
100- if (rand <= cumulative_prob) {
86+ const double g =
87+ -std::log (-std::log (dist (sa_system.random_number_engine )));
88+ const double z = -beta * sa_system.GetEnergyDifference (
89+ index, var.GetValueFromState (i)) + g;
90+ if (z > max_z) {
91+ max_z = z;
10192 selected_state_number = i;
102- break ;
10393 }
10494 }
105-
10695 if (selected_state_number == -1 ) {
10796 throw std::runtime_error (" No state selected." );
10897 }
109-
11098 return var.GetValueFromState (selected_state_number);
11199 }
112100
0 commit comments