Skip to content

Commit 1cd1829

Browse files
committed
Generalize Partition class, add unit tests
1 parent d718863 commit 1cd1829

File tree

2 files changed

+43
-16
lines changed

2 files changed

+43
-16
lines changed

tests/partition_tests.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#include "doctest.h"
2+
#include "util/partition.hh"
3+
#include <array>
4+
5+
TEST_CASE("Basic usage") {
6+
7+
Partition<2, 8> parts;
8+
9+
auto nums = std::array{1u, 9u, 3u, 5u, 2u, 10u};
10+
parts.calculate(nums);
11+
12+
CHECK(parts.parts[0].size() == 3);
13+
CHECK(nums[parts.parts[0][0]] == 10);
14+
CHECK(nums[parts.parts[0][1]] == 3);
15+
CHECK(nums[parts.parts[0][2]] == 2);
16+
17+
CHECK(parts.parts[1].size() == 3);
18+
CHECK(nums[parts.parts[1][0]] == 9);
19+
CHECK(nums[parts.parts[1][1]] == 5);
20+
CHECK(nums[parts.parts[1][2]] == 1);
21+
}

util/partition.hh

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,25 @@
22
#include "util/fixed_vector.hh"
33
#include <span>
44

5-
template<unsigned MaxElements>
6-
struct PartitionTwo {
5+
// Splits a list of elements into N groups, the sum of each group being roughly equal.
6+
// The **indices** of the original elements are stored resulting `parts`
7+
template<unsigned NumPartitions, unsigned MaxElements, typename T = unsigned>
8+
struct Partition {
9+
// parts[N] contains the indices of the elements whose sums are roughly equal to other parts[]
10+
std::array<FixedVector<unsigned, MaxElements>, NumPartitions> parts{};
711

8-
FixedVector<unsigned, MaxElements> a{};
9-
FixedVector<unsigned, MaxElements> b{};
12+
Partition() = default;
13+
14+
Partition(std::span<T> vals) {
15+
calculate(vals);
16+
}
17+
18+
void calculate(std::span<T> vals) {
19+
for (auto &part : parts)
20+
part.clear();
1021

11-
PartitionTwo(std::span<unsigned> vals) {
1222
struct IdVal {
13-
unsigned val;
23+
T val;
1424
unsigned id;
1525
};
1626

@@ -21,18 +31,14 @@ struct PartitionTwo {
2131

2232
std::ranges::sort(ordered, std::greater{}, &IdVal::val);
2333

24-
unsigned a_sum = 0;
25-
unsigned b_sum = 0;
34+
std::array<T, NumPartitions> sums{};
2635

27-
// The part with the smaller sum gets the next element
36+
// The part with the smallest sum gets the next element
2837
for (auto [val, i] : ordered) {
29-
if (a_sum <= b_sum) {
30-
a_sum += val;
31-
a.push_back(i);
32-
} else {
33-
b_sum += val;
34-
b.push_back(i);
35-
}
38+
auto min_sum = std::ranges::min_element(sums);
39+
auto idx = std::ranges::distance(sums.begin(), min_sum);
40+
*min_sum += val;
41+
parts[idx].push_back(i);
3642
}
3743
}
3844
};

0 commit comments

Comments
 (0)