Skip to content

Commit a8aaadf

Browse files
add push all API
1 parent 6941a6f commit a8aaadf

File tree

2 files changed

+92
-1
lines changed

2 files changed

+92
-1
lines changed

libc/src/__support/mpmc_stack.h

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ template <class T> class MPMCStack {
2929
public:
3030
static_assert(cpp::is_copy_constructible<T>::value,
3131
"T must be copy constructible");
32-
LIBC_INLINE MPMCStack() : head(nullptr) {}
32+
LIBC_INLINE constexpr MPMCStack() : head(nullptr) {}
3333
LIBC_INLINE bool push(T value) {
3434
AllocChecker ac;
3535
Node *new_node = new (ac) Node(value);
@@ -41,6 +41,39 @@ template <class T> class MPMCStack {
4141
});
4242
return true;
4343
}
44+
LIBC_INLINE bool push_all(T values[], size_t count) {
45+
struct Guard {
46+
size_t count;
47+
Node **allocated;
48+
LIBC_INLINE Guard(Node *allocated[]) : count(0), allocated(allocated) {}
49+
LIBC_INLINE ~Guard() {
50+
for (size_t i = 0; i < count; ++i)
51+
delete allocated[i];
52+
}
53+
LIBC_INLINE void add(Node *node) { allocated[count++] = node; }
54+
LIBC_INLINE void clear() { count = 0; }
55+
};
56+
// Variable sized array is a GNU extension.
57+
__extension__ Node *allocated[count];
58+
{
59+
Guard guard(allocated);
60+
for (size_t i = 0; i < count; ++i) {
61+
AllocChecker ac;
62+
Node *new_node = new (ac) Node(values[i]);
63+
if (!ac)
64+
return false;
65+
guard.add(new_node);
66+
if (i != 0)
67+
new_node->next = allocated[i - 1];
68+
}
69+
guard.clear();
70+
}
71+
head.transaction([&allocated, count](Node *old_head) {
72+
allocated[0]->next = old_head;
73+
return allocated[count - 1];
74+
});
75+
return true;
76+
}
4477
LIBC_INLINE cpp::optional<T> pop() {
4578
cpp::optional<T> res = cpp::nullopt;
4679
Node *node = nullptr;

libc/test/integration/src/__support/mpmc_stack_test.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,66 @@ void multithread_test() {
5454
__builtin_trap();
5555
}
5656

57+
void multithread_push_all_test() {
58+
constexpr static size_t NUM_THREADS = 4;
59+
constexpr static size_t BATCH_SIZE = 10;
60+
constexpr static size_t NUM_BATCHES = 20;
61+
struct State {
62+
MPMCStack<size_t> stack;
63+
cpp::Atomic<size_t> counter = 0;
64+
cpp::Atomic<bool> flags[NUM_THREADS * BATCH_SIZE * NUM_BATCHES];
65+
} state;
66+
pthread_t threads[NUM_THREADS];
67+
68+
for (size_t i = 0; i < NUM_THREADS; ++i) {
69+
LIBC_NAMESPACE::pthread_create(
70+
&threads[i], nullptr,
71+
[](void *arg) -> void * {
72+
State *state = static_cast<State *>(arg);
73+
size_t values[BATCH_SIZE];
74+
75+
for (size_t batch = 0; batch < NUM_BATCHES; ++batch) {
76+
// Prepare batch of values
77+
for (size_t j = 0; j < BATCH_SIZE; ++j) {
78+
size_t current = state->counter.fetch_add(1);
79+
values[j] = current;
80+
}
81+
82+
// Push all values in batch
83+
if (!state->stack.push_all(values, BATCH_SIZE))
84+
__builtin_trap();
85+
}
86+
87+
// Pop and mark all values
88+
while (auto res = state->stack.pop()) {
89+
size_t value = res.value();
90+
if (value < NUM_THREADS * BATCH_SIZE * NUM_BATCHES)
91+
state->flags[value].store(true);
92+
}
93+
return nullptr;
94+
},
95+
&state);
96+
}
97+
98+
for (pthread_t thread : threads)
99+
LIBC_NAMESPACE::pthread_join(thread, nullptr);
100+
101+
// Pop any remaining values
102+
while (cpp::optional<size_t> res = state.stack.pop()) {
103+
size_t value = res.value();
104+
if (value < NUM_THREADS * BATCH_SIZE * NUM_BATCHES)
105+
state.flags[value].store(true);
106+
}
107+
108+
// Verify all values were processed
109+
for (size_t i = 0; i < NUM_THREADS * BATCH_SIZE * NUM_BATCHES; ++i)
110+
if (!state.flags[i].load())
111+
__builtin_trap();
112+
}
113+
57114
TEST_MAIN() {
58115
smoke_test();
59116
multithread_test();
117+
multithread_push_all_test();
60118
return 0;
61119
}

0 commit comments

Comments
 (0)