22#include " util/fixed_vector.hh"
33#include < span>
44
5- template <unsigned MaxElements>
6- struct PartitionTwo {
5+ // Splits a list of elements into N groups, the sum of each group being roughly equal.
6+ // The **indices** of the original elements are stored resulting `parts`
7+ template <unsigned NumPartitions, unsigned MaxElements, typename T = unsigned >
8+ struct Partition {
9+ // parts[N] contains the indices of the elements whose sums are roughly equal to other parts[]
10+ std::array<FixedVector<unsigned , MaxElements>, NumPartitions> parts{};
711
8- FixedVector<unsigned , MaxElements> a{};
9- FixedVector<unsigned , MaxElements> b{};
12+ Partition () = default ;
13+
14+ Partition (std::span<T> vals) {
15+ calculate (vals);
16+ }
17+
18+ void calculate (std::span<T> vals) {
19+ for (auto &part : parts)
20+ part.clear ();
1021
11- PartitionTwo (std::span<unsigned > vals) {
1222 struct IdVal {
13- unsigned val;
23+ T val;
1424 unsigned id;
1525 };
1626
@@ -21,18 +31,14 @@ struct PartitionTwo {
2131
2232 std::ranges::sort (ordered, std::greater{}, &IdVal::val);
2333
24- unsigned a_sum = 0 ;
25- unsigned b_sum = 0 ;
34+ std::array<T, NumPartitions> sums{};
2635
27- // The part with the smaller sum gets the next element
36+ // The part with the smallest sum gets the next element
2837 for (auto [val, i] : ordered) {
29- if (a_sum <= b_sum) {
30- a_sum += val;
31- a.push_back (i);
32- } else {
33- b_sum += val;
34- b.push_back (i);
35- }
38+ auto min_sum = std::ranges::min_element (sums);
39+ auto idx = std::ranges::distance (sums.begin (), min_sum);
40+ *min_sum += val;
41+ parts[idx].push_back (i);
3642 }
3743 }
3844};
0 commit comments