diff --git a/libc/src/__support/CMakeLists.txt b/libc/src/__support/CMakeLists.txt index 2196d9e23bba7..7103bf7f5fc21 100644 --- a/libc/src/__support/CMakeLists.txt +++ b/libc/src/__support/CMakeLists.txt @@ -398,6 +398,31 @@ add_header_library( libc.src.__support.macros.attributes ) +add_header_library( + aba_ptr + HDRS + aba_ptr.h + DEPENDS + libc.hdr.types.size_t + libc.src.__support.common + libc.src.__support.CPP.atomic + libc.src.__support.threads.sleep +) + +add_header_library( + mpmc_stack + HDRS + mpmc_stack.h + DEPENDS + libc.src.__support.aba_ptr + libc.src.__support.common + libc.src.__support.CPP.atomic + libc.src.__support.CPP.new + libc.src.__support.CPP.optional + libc.src.__support.CPP.type_traits +) + + add_subdirectory(FPUtil) add_subdirectory(OSUtil) add_subdirectory(StringUtil) diff --git a/libc/src/__support/aba_ptr.h b/libc/src/__support/aba_ptr.h new file mode 100644 index 0000000000000..0da752daeb4ce --- /dev/null +++ b/libc/src/__support/aba_ptr.h @@ -0,0 +1,89 @@ +//===-- Transactional Ptr for ABA prevention --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC___SUPPORT_TAGGED_POINTER_H +#define LLVM_LIBC_SRC___SUPPORT_TAGGED_POINTER_H + +#include "hdr/types/size_t.h" +#include "src/__support/CPP/atomic.h" +#include "src/__support/common.h" +#include "src/__support/threads/sleep.h" + +#ifdef __GCC_HAVE_SYNC_COMPARE_AND_SWAP_16 +#define LIBC_ABA_PTR_IS_ATOMIC true +#else +#define LIBC_ABA_PTR_IS_ATOMIC false +#endif + +namespace LIBC_NAMESPACE_DECL { + +template struct AbaPtrImpl { + union Impl { + struct alignas(2 * alignof(void *)) Atomic { + T *ptr; + size_t tag; + } atomic; + struct Mutex { + T *ptr; + bool locked; + } mutex; + } impl; + + LIBC_INLINE constexpr AbaPtrImpl(T *ptr) + : impl(IsAtomic ? Impl{.atomic{ptr, 0}} : Impl{.mutex{ptr, false}}) {} + + /// User must guarantee that operation is redoable. + template LIBC_INLINE void transaction(Op &&op) { + if constexpr (IsAtomic) { + for (;;) { + cpp::AtomicRef ref(impl.atomic); + typename Impl::Atomic snapshot, next; + snapshot = ref.load(cpp::MemoryOrder::RELAXED); + next.ptr = op(snapshot.ptr); + // Wrapping add for unsigned integers. + next.tag = snapshot.tag + 1; + // Redo transaction can be costly, so we use strong version. + if (ref.compare_exchange_strong(snapshot, next, + cpp::MemoryOrder::ACQ_REL, + cpp::MemoryOrder::RELAXED)) + return; + } + } else { + // Acquire the lock. + cpp::AtomicRef ref(impl.mutex.locked); + while (ref.exchange(true, cpp::MemoryOrder::ACQUIRE)) + while (ref.load(cpp::MemoryOrder::RELAXED)) + LIBC_NAMESPACE::sleep_briefly(); + + impl.mutex.ptr = op(impl.mutex.ptr); + // Release the lock. + ref.store(false, cpp::MemoryOrder::RELEASE); + } + } + + LIBC_INLINE T *get() const { + if constexpr (IsAtomic) { + // Weak micro-architectures typically regards simultaneous partial word + // loading and full word loading as a race condition. While there are + // implementations that uses racy read anyway, we still load the whole + // word to avoid any complications. + typename Impl::Atomic snapshot; + cpp::AtomicRef ref(impl.atomic); + snapshot = ref.load(cpp::MemoryOrder::RELAXED); + return snapshot.ptr; + } else { + return impl.mutex.ptr; + } + } +}; + +template using AbaPtr = AbaPtrImpl; +} // namespace LIBC_NAMESPACE_DECL + +#undef LIBC_ABA_PTR_IS_ATOMIC +#endif diff --git a/libc/src/__support/mpmc_stack.h b/libc/src/__support/mpmc_stack.h new file mode 100644 index 0000000000000..259732ef84568 --- /dev/null +++ b/libc/src/__support/mpmc_stack.h @@ -0,0 +1,116 @@ +//===-- Simple Lock-free MPMC Stack -----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC___SUPPORT_MPMC_STACK_H +#define LLVM_LIBC_SRC___SUPPORT_MPMC_STACK_H + +#include "src/__support/CPP/atomic.h" +#include "src/__support/CPP/new.h" +#include "src/__support/CPP/optional.h" +#include "src/__support/aba_ptr.h" + +namespace LIBC_NAMESPACE_DECL { +template class MPMCStack { + struct Node { + cpp::Atomic visitor; + Node *next; + T value; + + LIBC_INLINE Node(T val) : visitor(0), next(nullptr), value(val) {} + }; + AbaPtr head; + +public: + static_assert(cpp::is_copy_constructible::value, + "T must be copy constructible"); + LIBC_INLINE constexpr MPMCStack() : head(nullptr) {} + LIBC_INLINE bool push(T value) { + AllocChecker ac; + Node *new_node = new (ac) Node(value); + if (!ac) + return false; + head.transaction([new_node](Node *old_head) { + new_node->next = old_head; + return new_node; + }); + return true; + } + LIBC_INLINE bool push_all(T values[], size_t count) { + struct Guard { + Node *cursor; + LIBC_INLINE Guard() : cursor(nullptr) {} + LIBC_INLINE ~Guard() { + while (cursor) { + Node *next = cursor->next; + delete cursor; + cursor = next; + } + } + LIBC_INLINE void advance(Node *node) { + node->next = cursor; + cursor = node; + } + LIBC_INLINE Node *finish() { + Node *ret = cursor; + cursor = nullptr; + return ret; + } + }; + Node *first = nullptr; + Node *last = nullptr; + { + Guard guard{}; + for (size_t i = 0; i < count; ++i) { + AllocChecker ac; + Node *new_node = new (ac) Node(values[i]); + if (!ac) + return false; + if (i == 0) + first = new_node; + guard.advance(new_node); + } + last = guard.finish(); + } + head.transaction([first, last](Node *old_head) { + first->next = old_head; + return last; + }); + return true; + } + LIBC_INLINE cpp::optional pop() { + cpp::optional res = cpp::nullopt; + Node *node = nullptr; + head.transaction([&](Node *current_head) -> Node * { + if (!current_head) { + res = cpp::nullopt; + return nullptr; + } + node = current_head; + node->visitor.fetch_add(1, cpp::MemoryOrder::ACQUIRE); + res = cpp::optional{node->value}; + Node *next = node->next; + node->visitor.fetch_sub(1, cpp::MemoryOrder::RELEASE); + return next; + }); + // On a successful transaction, a node is popped by us. So we must delete + // it. When we are at here, no one else can acquire + // new reference to the node, but we still need to wait until other threads + // inside the transaction who may potentially be holding a reference to the + // node. + if (res) { + // Spin until the node is no longer in use. + while (node->visitor.load(cpp::MemoryOrder::RELAXED) != 0) + LIBC_NAMESPACE::sleep_briefly(); + delete node; + } + return res; + } +}; +} // namespace LIBC_NAMESPACE_DECL + +#endif diff --git a/libc/test/integration/src/__support/CMakeLists.txt b/libc/test/integration/src/__support/CMakeLists.txt index b5b6557e8d689..93f54083f3c00 100644 --- a/libc/test/integration/src/__support/CMakeLists.txt +++ b/libc/test/integration/src/__support/CMakeLists.txt @@ -2,3 +2,18 @@ add_subdirectory(threads) if(LIBC_TARGET_OS_IS_GPU) add_subdirectory(GPU) endif() + +add_libc_integration_test_suite(libc-support-integration-tests) + +add_integration_test( + mpmc_stack_test + SUITE + libc-support-integration-tests + SRCS + mpmc_stack_test.cpp + DEPENDS + libc.src.__support.mpmc_stack + libc.src.__support.threads.thread + libc.src.pthread.pthread_create + libc.src.pthread.pthread_join +) diff --git a/libc/test/integration/src/__support/mpmc_stack_test.cpp b/libc/test/integration/src/__support/mpmc_stack_test.cpp new file mode 100644 index 0000000000000..9166a816a74fe --- /dev/null +++ b/libc/test/integration/src/__support/mpmc_stack_test.cpp @@ -0,0 +1,119 @@ +#include "src/__support/CPP/atomic.h" +#include "src/__support/mpmc_stack.h" +#include "src/pthread/pthread_create.h" +#include "src/pthread/pthread_join.h" +#include "test/IntegrationTest/test.h" + +using namespace LIBC_NAMESPACE; + +void smoke_test() { + MPMCStack stack; + for (int i = 0; i <= 100; ++i) + if (!stack.push(i)) + __builtin_trap(); + for (int i = 100; i >= 0; --i) + if (*stack.pop() != i) + __builtin_trap(); + if (stack.pop()) + __builtin_trap(); // Should be empty now. +} + +void multithread_test() { + constexpr static size_t NUM_THREADS = 5; + constexpr static size_t NUM_PUSHES = 100; + struct State { + MPMCStack stack; + cpp::Atomic counter = 0; + cpp::Atomic flags[NUM_PUSHES]; + } state; + pthread_t threads[NUM_THREADS]; + for (size_t i = 0; i < NUM_THREADS; ++i) { + LIBC_NAMESPACE::pthread_create( + &threads[i], nullptr, + [](void *arg) -> void * { + State *state = static_cast(arg); + for (;;) { + size_t current = state->counter.fetch_add(1); + if (current >= NUM_PUSHES) + break; + if (!state->stack.push(current)) + __builtin_trap(); + } + while (auto res = state->stack.pop()) + state->flags[res.value()].store(true); + return nullptr; + }, + &state); + } + for (pthread_t thread : threads) + LIBC_NAMESPACE::pthread_join(thread, nullptr); + while (cpp::optional res = state.stack.pop()) + state.flags[res.value()].store(true); + for (size_t i = 0; i < NUM_PUSHES; ++i) + if (!state.flags[i].load()) + __builtin_trap(); +} + +void multithread_push_all_test() { + constexpr static size_t NUM_THREADS = 4; + constexpr static size_t BATCH_SIZE = 10; + constexpr static size_t NUM_BATCHES = 20; + struct State { + MPMCStack stack; + cpp::Atomic counter = 0; + cpp::Atomic flags[NUM_THREADS * BATCH_SIZE * NUM_BATCHES]; + } state; + pthread_t threads[NUM_THREADS]; + + for (size_t i = 0; i < NUM_THREADS; ++i) { + LIBC_NAMESPACE::pthread_create( + &threads[i], nullptr, + [](void *arg) -> void * { + State *state = static_cast(arg); + size_t values[BATCH_SIZE]; + + for (size_t batch = 0; batch < NUM_BATCHES; ++batch) { + // Prepare batch of values + for (size_t j = 0; j < BATCH_SIZE; ++j) { + size_t current = state->counter.fetch_add(1); + values[j] = current; + } + + // Push all values in batch + if (!state->stack.push_all(values, BATCH_SIZE)) + __builtin_trap(); + } + + // Pop and mark all values + while (auto res = state->stack.pop()) { + size_t value = res.value(); + if (value < NUM_THREADS * BATCH_SIZE * NUM_BATCHES) + state->flags[value].store(true); + } + return nullptr; + }, + &state); + } + + for (pthread_t thread : threads) + LIBC_NAMESPACE::pthread_join(thread, nullptr); + + // Pop any remaining values + while (cpp::optional res = state.stack.pop()) { + size_t value = res.value(); + if (value < NUM_THREADS * BATCH_SIZE * NUM_BATCHES) + state.flags[value].store(true); + } + + // Verify all values were processed + for (size_t i = 0; i < NUM_THREADS * BATCH_SIZE * NUM_BATCHES; ++i) + if (!state.flags[i].load()) + __builtin_trap(); +} + +TEST_MAIN() { + smoke_test(); + multithread_test(); + multithread_push_all_test(); + return 0; +}