Skip to content

Commit c8cf8f3

Browse files
OPT_METROPOLISを4次項まで対応 (#502)
# 変更 `sample_huio`関数において、`updater=OPT_METROPOLIS`とすると、ある変数に対して最もエネルギーが下がるような最適遷移が行われます。 現在のmainの実装ではこの最適遷移が行われるのはz^2のような2乗項だけでしたが、これを4乗項まで行うように変更しました。 --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent b773dfd commit c8cf8f3

File tree

10 files changed

+425
-99
lines changed

10 files changed

+425
-99
lines changed

include/openjij/graph/integer_polynomial_model.hpp

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -83,29 +83,19 @@ class IntegerPolynomialModel {
8383
std::sort(interactions.begin(), interactions.end());
8484
}
8585

86-
// Create only_multilinear_index_set and under_quadratic_index_set
86+
// Create max degree of each variable
87+
this->max_variable_degree_.resize(this->num_variables_, 0);
8788
for (std::int64_t index = 0; index < this->num_variables_; ++index) {
88-
bool is_multilinear = true;
89+
std::int64_t max_degree = 0;
8990
for (const auto &[_, degree] : this->index_to_interactions_[index]) {
90-
if (degree != 1) {
91-
is_multilinear = false;
92-
break;
91+
if (degree > max_degree) {
92+
max_degree = degree;
9393
}
9494
}
95-
if (is_multilinear) {
96-
this->only_multilinear_index_set_.insert(index);
97-
}
98-
99-
bool is_under_quadratic = true;
100-
for (const auto &[_, degree] : this->index_to_interactions_[index]) {
101-
if (degree > 2) {
102-
is_under_quadratic = false;
103-
break;
104-
}
105-
}
106-
if (is_under_quadratic) {
107-
this->under_quadratic_index_set_.insert(index);
95+
if (max_degree == 0) {
96+
throw std::runtime_error("Variable with no interactions found.");
10897
}
98+
this->max_variable_degree_[index] = max_degree;
10999
}
110100
}
111101

@@ -182,11 +172,11 @@ class IntegerPolynomialModel {
182172
GetIndexToInteractions() const {
183173
return this->index_to_interactions_;
184174
}
185-
const std::unordered_set<std::int64_t> &GetOnlyMultilinearIndexSet() const {
186-
return this->only_multilinear_index_set_;
175+
std::int64_t GetEachVariableDegreeAt(const std::int64_t index) const {
176+
return this->max_variable_degree_[index];
187177
}
188-
const std::unordered_set<std::int64_t> &GetUnderQuadraticIndexSet() const {
189-
return this->under_quadratic_index_set_;
178+
const std::vector<std::int64_t> &GetEachVariableDegree() const {
179+
return this->max_variable_degree_;
190180
}
191181

192182
private:
@@ -199,8 +189,7 @@ class IntegerPolynomialModel {
199189
key_value_list_;
200190
std::vector<std::vector<std::pair<std::size_t, std::int64_t>>>
201191
index_to_interactions_;
202-
std::unordered_set<std::int64_t> only_multilinear_index_set_;
203-
std::unordered_set<std::int64_t> under_quadratic_index_set_;
192+
std::vector<std::int64_t> max_variable_degree_;
204193
};
205194

206195
} // namespace graph

include/openjij/system/integer_polynomial_sa_system.hpp

Lines changed: 71 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#pragma once
1616

1717
#include "../utility/variable.hpp"
18+
#include "../utility/min_polynomial.hpp"
1819
#include "./sa_system.hpp"
1920
#include <cstdint>
2021
#include <random>
@@ -186,64 +187,69 @@ class IntegerSASystem<graph::IntegerPolynomialModel, RandType> {
186187
}
187188

188189
std::pair<int, double> GetMinEnergyDifference(std::int64_t index) {
189-
if (this->UnderQuadraticCoeff(index)) {
190-
const auto &x = this->state_[index];
191-
const std::int64_t dxl = x.lower_bound - x.value;
192-
const std::int64_t dxu = x.upper_bound - x.value;
190+
const auto &x = this->state_[index];
191+
const std::int64_t dxl = x.lower_bound - x.value;
192+
const std::int64_t dxu = x.upper_bound - x.value;
193193

194+
if (this->UnderQuadraticCoeff(index)) {
194195
auto it_a = this->base_energy_difference_[index].find(2);
195196
const double a = (it_a != this->base_energy_difference_[index].end())
196197
? it_a->second
197198
: 0.0;
198-
199199
auto it_b = this->base_energy_difference_[index].find(1);
200-
const double b_base = (it_b != this->base_energy_difference_[index].end())
200+
const double b = (it_b != this->base_energy_difference_[index].end())
201201
? it_b->second
202202
: 0.0;
203-
const double b = b_base + 2 * x.value * a;
204-
205-
if (a > 0) {
206-
const double center = -b / (2 * a);
207-
if (dxu <= center) {
208-
return {x.upper_bound,
209-
this->GetEnergyDifference(index, x.upper_bound)};
210-
} else if (dxl < center && center < dxu) {
211-
const std::int64_t dx_left =
212-
static_cast<std::int64_t>(std::floor(center));
213-
const std::int64_t dx_right =
214-
static_cast<std::int64_t>(std::ceil(center));
215-
if (center - dx_left <= dx_right - center) {
216-
return {x.value + dx_left,
217-
this->GetEnergyDifference(index, x.value + dx_left)};
218-
} else {
219-
return {x.value + dx_right,
220-
this->GetEnergyDifference(index, x.value + dx_right)};
221-
}
222-
} else if (dxl >= center) {
223-
return {x.lower_bound,
224-
this->GetEnergyDifference(index, x.lower_bound)};
225-
} else {
226-
throw std::runtime_error("Invalid state in GetMinEnergyDifference");
227-
}
228-
} else if (a == 0) {
229-
if (b > 0) {
230-
return {x.lower_bound,
231-
this->GetEnergyDifference(index, x.lower_bound)};
232-
} else if (b < 0) {
233-
return {x.upper_bound,
234-
this->GetEnergyDifference(index, x.upper_bound)};
235-
} else {
236-
return {x.GenerateRandomValue(this->random_number_engine), 0.0};
237-
}
238-
} else { // a < 0
239-
const double dE_lower = this->GetEnergyDifference(index, x.lower_bound);
240-
const double dE_upper = this->GetEnergyDifference(index, x.upper_bound);
241-
if (dE_lower <= dE_upper) {
242-
return {x.lower_bound, dE_lower};
243-
} else {
244-
return {x.upper_bound, dE_upper};
245-
}
246-
}
203+
204+
const double aa = a;
205+
const double bb = b + 2 * x.value * a;
206+
207+
return utility::FindMinimumIntegerQuadratic(aa, bb, dxl, dxu, x.value, this->random_number_engine);
208+
}
209+
else if (this->IsCubicCoeff(index)) {
210+
auto it_a = this->base_energy_difference_[index].find(3);
211+
const double a = (it_a != this->base_energy_difference_[index].end())
212+
? it_a->second
213+
: 0.0;
214+
auto it_b = this->base_energy_difference_[index].find(2);
215+
const double b = (it_b != this->base_energy_difference_[index].end())
216+
? it_b->second
217+
: 0.0;
218+
auto it_c = this->base_energy_difference_[index].find(1);
219+
const double c = (it_c != this->base_energy_difference_[index].end())
220+
? it_c->second
221+
: 0.0;
222+
223+
const double aa = a;
224+
const double bb = 3 * a * x.value + b;
225+
const double cc = 3 * a * x.value * x.value + 2 * b * x.value + c;
226+
227+
return utility::FindMinimumIntegerCubic(aa, bb, cc, dxl, dxu, x.value, this->random_number_engine);
228+
} else if (this->IsQuarticCoeff(index)) {
229+
auto it_a = this->base_energy_difference_[index].find(4);
230+
const double a = (it_a != this->base_energy_difference_[index].end())
231+
? it_a->second
232+
: 0.0;
233+
auto it_b = this->base_energy_difference_[index].find(3);
234+
const double b = (it_b != this->base_energy_difference_[index].end())
235+
? it_b->second
236+
: 0.0;
237+
auto it_c = this->base_energy_difference_[index].find(2);
238+
const double c = (it_c != this->base_energy_difference_[index].end())
239+
? it_c->second
240+
: 0.0;
241+
auto it_d = this->base_energy_difference_[index].find(1);
242+
const double d = (it_d != this->base_energy_difference_[index].end())
243+
? it_d->second
244+
: 0.0;
245+
246+
const double aa = a;
247+
const double bb = 4 * a * x.value + b;
248+
const double cc = 6 * a * x.value * x.value + 3 * b * x.value + c;
249+
const double dd = 4 * a * x.value * x.value * x.value +
250+
3 * b * x.value * x.value + 2 * c * x.value + d;
251+
252+
return utility::FindMinimumIntegerQuartic(aa, bb, cc, dd, dxl, dxu, x.value, this->random_number_engine);
247253
} else {
248254
double min_dE = std::numeric_limits<double>::infinity();
249255
std::int64_t min_value = -1;
@@ -281,12 +287,24 @@ class IntegerSASystem<graph::IntegerPolynomialModel, RandType> {
281287

282288
double GetEnergy() const { return this->energy_; }
283289

284-
bool OnlyMultiLinearCoeff(std::int64_t index) const {
285-
return this->model.GetOnlyMultilinearIndexSet().count(index) > 0;
290+
bool IsLinearCoeff(std::int64_t index) const {
291+
return this->model.GetEachVariableDegreeAt(index) == 1;
286292
}
287293

288294
bool UnderQuadraticCoeff(std::int64_t index) const {
289-
return this->model.GetUnderQuadraticIndexSet().count(index) > 0;
295+
return (this->model.GetEachVariableDegreeAt(index) == 2) || (this->model.GetEachVariableDegreeAt(index) == 1);
296+
}
297+
298+
bool IsCubicCoeff(std::int64_t index) const {
299+
return this->model.GetEachVariableDegreeAt(index) == 3;
300+
}
301+
302+
bool IsQuarticCoeff(std::int64_t index) const {
303+
return this->model.GetEachVariableDegreeAt(index) == 4;
304+
}
305+
306+
bool CanOptMove(std::int64_t index) const {
307+
return this->UnderQuadraticCoeff(index) || this->IsCubicCoeff(index) || this->IsQuarticCoeff(index);
290308
}
291309

292310
public:

include/openjij/system/integer_quadratic_sa_system.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,11 @@ class IntegerSASystem<graph::IntegerQuadraticModel, RandType> {
173173

174174
double GetEnergy() const { return this->energy_; }
175175

176-
bool OnlyMultiLinearCoeff(std::int64_t index) const {
176+
bool IsLinearCoeff(std::int64_t index) const {
177177
return this->model.GetOnlyBilinearIndexSet().count(index) > 0;
178178
}
179179

180-
bool UnderQuadraticCoeff(std::int64_t index) const { return true; }
180+
bool CanOptMove(std::int64_t index) const { return true; }
181181

182182
double GetLinearCoeff(std::int64_t index) const {
183183
return this->linear_coeff_[index];

include/openjij/updater/single_integer_move.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ struct OptMetropolisUpdater {
3939
std::int64_t GenerateNewValue(SystemType &sa_system, const std::int64_t index,
4040
const double T, const double progress) {
4141
// Metropolis Optimal Transition if possible
42-
// This is used for systems with only quadratic coefficientsa
43-
if (sa_system.UnderQuadraticCoeff(index) && dist(sa_system.random_number_engine) < progress) {
42+
// This is used for systems with up to 4th power coefficients
43+
if (sa_system.CanOptMove(index) && dist(sa_system.random_number_engine) < progress) {
4444
const auto [min_val, min_dE] = sa_system.GetMinEnergyDifference(index);
4545
if (min_dE <= 0.0 ||
4646
dist(sa_system.random_number_engine) < std::exp(-min_dE / T)) {
@@ -67,7 +67,7 @@ struct HeatBathUpdater {
6767
template <typename SystemType>
6868
std::int64_t GenerateNewValue(SystemType &sa_system, const std::int64_t index,
6969
const double T, const double _progress) {
70-
if (sa_system.OnlyMultiLinearCoeff(index)) {
70+
if (sa_system.IsLinearCoeff(index)) {
7171
return ForBilinear(sa_system, index, T, _progress);
7272
} else {
7373
return ForAll(sa_system, index, T, _progress);

0 commit comments

Comments
 (0)