|
15 | 15 | #pragma once |
16 | 16 |
|
17 | 17 | #include "../utility/variable.hpp" |
| 18 | +#include "../utility/min_polynomial.hpp" |
18 | 19 | #include "./sa_system.hpp" |
19 | 20 | #include <cstdint> |
20 | 21 | #include <random> |
@@ -186,64 +187,69 @@ class IntegerSASystem<graph::IntegerPolynomialModel, RandType> { |
186 | 187 | } |
187 | 188 |
|
188 | 189 | std::pair<int, double> GetMinEnergyDifference(std::int64_t index) { |
189 | | - if (this->UnderQuadraticCoeff(index)) { |
190 | | - const auto &x = this->state_[index]; |
191 | | - const std::int64_t dxl = x.lower_bound - x.value; |
192 | | - const std::int64_t dxu = x.upper_bound - x.value; |
| 190 | + const auto &x = this->state_[index]; |
| 191 | + const std::int64_t dxl = x.lower_bound - x.value; |
| 192 | + const std::int64_t dxu = x.upper_bound - x.value; |
193 | 193 |
|
| 194 | + if (this->UnderQuadraticCoeff(index)) { |
194 | 195 | auto it_a = this->base_energy_difference_[index].find(2); |
195 | 196 | const double a = (it_a != this->base_energy_difference_[index].end()) |
196 | 197 | ? it_a->second |
197 | 198 | : 0.0; |
198 | | - |
199 | 199 | auto it_b = this->base_energy_difference_[index].find(1); |
200 | | - const double b_base = (it_b != this->base_energy_difference_[index].end()) |
| 200 | + const double b = (it_b != this->base_energy_difference_[index].end()) |
201 | 201 | ? it_b->second |
202 | 202 | : 0.0; |
203 | | - const double b = b_base + 2 * x.value * a; |
204 | | - |
205 | | - if (a > 0) { |
206 | | - const double center = -b / (2 * a); |
207 | | - if (dxu <= center) { |
208 | | - return {x.upper_bound, |
209 | | - this->GetEnergyDifference(index, x.upper_bound)}; |
210 | | - } else if (dxl < center && center < dxu) { |
211 | | - const std::int64_t dx_left = |
212 | | - static_cast<std::int64_t>(std::floor(center)); |
213 | | - const std::int64_t dx_right = |
214 | | - static_cast<std::int64_t>(std::ceil(center)); |
215 | | - if (center - dx_left <= dx_right - center) { |
216 | | - return {x.value + dx_left, |
217 | | - this->GetEnergyDifference(index, x.value + dx_left)}; |
218 | | - } else { |
219 | | - return {x.value + dx_right, |
220 | | - this->GetEnergyDifference(index, x.value + dx_right)}; |
221 | | - } |
222 | | - } else if (dxl >= center) { |
223 | | - return {x.lower_bound, |
224 | | - this->GetEnergyDifference(index, x.lower_bound)}; |
225 | | - } else { |
226 | | - throw std::runtime_error("Invalid state in GetMinEnergyDifference"); |
227 | | - } |
228 | | - } else if (a == 0) { |
229 | | - if (b > 0) { |
230 | | - return {x.lower_bound, |
231 | | - this->GetEnergyDifference(index, x.lower_bound)}; |
232 | | - } else if (b < 0) { |
233 | | - return {x.upper_bound, |
234 | | - this->GetEnergyDifference(index, x.upper_bound)}; |
235 | | - } else { |
236 | | - return {x.GenerateRandomValue(this->random_number_engine), 0.0}; |
237 | | - } |
238 | | - } else { // a < 0 |
239 | | - const double dE_lower = this->GetEnergyDifference(index, x.lower_bound); |
240 | | - const double dE_upper = this->GetEnergyDifference(index, x.upper_bound); |
241 | | - if (dE_lower <= dE_upper) { |
242 | | - return {x.lower_bound, dE_lower}; |
243 | | - } else { |
244 | | - return {x.upper_bound, dE_upper}; |
245 | | - } |
246 | | - } |
| 203 | + |
| 204 | + const double aa = a; |
| 205 | + const double bb = b + 2 * x.value * a; |
| 206 | + |
| 207 | + return utility::FindMinimumIntegerQuadratic(aa, bb, dxl, dxu, x.value, this->random_number_engine); |
| 208 | + } |
| 209 | + else if (this->IsCubicCoeff(index)) { |
| 210 | + auto it_a = this->base_energy_difference_[index].find(3); |
| 211 | + const double a = (it_a != this->base_energy_difference_[index].end()) |
| 212 | + ? it_a->second |
| 213 | + : 0.0; |
| 214 | + auto it_b = this->base_energy_difference_[index].find(2); |
| 215 | + const double b = (it_b != this->base_energy_difference_[index].end()) |
| 216 | + ? it_b->second |
| 217 | + : 0.0; |
| 218 | + auto it_c = this->base_energy_difference_[index].find(1); |
| 219 | + const double c = (it_c != this->base_energy_difference_[index].end()) |
| 220 | + ? it_c->second |
| 221 | + : 0.0; |
| 222 | + |
| 223 | + const double aa = a; |
| 224 | + const double bb = 3 * a * x.value + b; |
| 225 | + const double cc = 3 * a * x.value * x.value + 2 * b * x.value + c; |
| 226 | + |
| 227 | + return utility::FindMinimumIntegerCubic(aa, bb, cc, dxl, dxu, x.value, this->random_number_engine); |
| 228 | + } else if (this->IsQuarticCoeff(index)) { |
| 229 | + auto it_a = this->base_energy_difference_[index].find(4); |
| 230 | + const double a = (it_a != this->base_energy_difference_[index].end()) |
| 231 | + ? it_a->second |
| 232 | + : 0.0; |
| 233 | + auto it_b = this->base_energy_difference_[index].find(3); |
| 234 | + const double b = (it_b != this->base_energy_difference_[index].end()) |
| 235 | + ? it_b->second |
| 236 | + : 0.0; |
| 237 | + auto it_c = this->base_energy_difference_[index].find(2); |
| 238 | + const double c = (it_c != this->base_energy_difference_[index].end()) |
| 239 | + ? it_c->second |
| 240 | + : 0.0; |
| 241 | + auto it_d = this->base_energy_difference_[index].find(1); |
| 242 | + const double d = (it_d != this->base_energy_difference_[index].end()) |
| 243 | + ? it_d->second |
| 244 | + : 0.0; |
| 245 | + |
| 246 | + const double aa = a; |
| 247 | + const double bb = 4 * a * x.value + b; |
| 248 | + const double cc = 6 * a * x.value * x.value + 3 * b * x.value + c; |
| 249 | + const double dd = 4 * a * x.value * x.value * x.value + |
| 250 | + 3 * b * x.value * x.value + 2 * c * x.value + d; |
| 251 | + |
| 252 | + return utility::FindMinimumIntegerQuartic(aa, bb, cc, dd, dxl, dxu, x.value, this->random_number_engine); |
247 | 253 | } else { |
248 | 254 | double min_dE = std::numeric_limits<double>::infinity(); |
249 | 255 | std::int64_t min_value = -1; |
@@ -281,12 +287,24 @@ class IntegerSASystem<graph::IntegerPolynomialModel, RandType> { |
281 | 287 |
|
282 | 288 | double GetEnergy() const { return this->energy_; } |
283 | 289 |
|
284 | | - bool OnlyMultiLinearCoeff(std::int64_t index) const { |
285 | | - return this->model.GetOnlyMultilinearIndexSet().count(index) > 0; |
| 290 | + bool IsLinearCoeff(std::int64_t index) const { |
| 291 | + return this->model.GetEachVariableDegreeAt(index) == 1; |
286 | 292 | } |
287 | 293 |
|
288 | 294 | bool UnderQuadraticCoeff(std::int64_t index) const { |
289 | | - return this->model.GetUnderQuadraticIndexSet().count(index) > 0; |
| 295 | + return (this->model.GetEachVariableDegreeAt(index) == 2) || (this->model.GetEachVariableDegreeAt(index) == 1); |
| 296 | + } |
| 297 | + |
| 298 | + bool IsCubicCoeff(std::int64_t index) const { |
| 299 | + return this->model.GetEachVariableDegreeAt(index) == 3; |
| 300 | + } |
| 301 | + |
| 302 | + bool IsQuarticCoeff(std::int64_t index) const { |
| 303 | + return this->model.GetEachVariableDegreeAt(index) == 4; |
| 304 | + } |
| 305 | + |
| 306 | + bool CanOptMove(std::int64_t index) const { |
| 307 | + return this->UnderQuadraticCoeff(index) || this->IsCubicCoeff(index) || this->IsQuarticCoeff(index); |
290 | 308 | } |
291 | 309 |
|
292 | 310 | public: |
|
0 commit comments