Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions include/pybind11/critical_section.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,33 @@ PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
class scoped_critical_section {
public:
#ifdef Py_GIL_DISABLED
explicit scoped_critical_section(handle obj) : has2(false) {
PyCriticalSection_Begin(&section, obj.ptr());
scoped_critical_section(handle obj1, handle obj2) {
if (obj1) {
if (obj2) {
PyCriticalSection2_Begin(&section2, obj1.ptr(), obj2.ptr());
rank = 2;
} else {
PyCriticalSection_Begin(&section, obj1.ptr());
rank = 1;
}
} else if (obj2) {
PyCriticalSection_Begin(&section, obj2.ptr());
rank = 1;
}
}

scoped_critical_section(handle obj1, handle obj2) : has2(true) {
PyCriticalSection2_Begin(&section2, obj1.ptr(), obj2.ptr());
explicit scoped_critical_section(handle obj) {
if (obj) {
PyCriticalSection_Begin(&section, obj.ptr());
rank = 1;
}
}

~scoped_critical_section() {
if (has2) {
PyCriticalSection2_End(&section2);
} else {
if (rank == 1) {
PyCriticalSection_End(&section);
} else if (rank == 2) {
PyCriticalSection2_End(&section2);
}
}
#else
Expand All @@ -39,7 +53,7 @@ class scoped_critical_section {

private:
#ifdef Py_GIL_DISABLED
bool has2;
int rank{0};
union {
PyCriticalSection section;
PyCriticalSection2 section2;
Expand Down
3 changes: 2 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ set(PYBIND11_TEST_FILES
test_potentially_slicing_weak_ptr
test_python_multiple_inheritance
test_pytypes
test_scoped_critical_section
test_sequences_and_iterators
test_smart_ptr
test_stl
Expand Down Expand Up @@ -566,7 +567,7 @@ set(PYBIND11_PYTEST_ARGS
# A single command to compile and run the tests
add_custom_target(
pytest
COMMAND ${PYBIND11_TEST_PREFIX_COMMAND} ${PYTHON_EXECUTABLE} -m pytest
COMMAND ${PYBIND11_TEST_PREFIX_COMMAND} ${PYTHON_EXECUTABLE} -X dev -X faulthandler -m pytest
${PYBIND11_ABS_PYTEST_FILES} ${PYBIND11_PYTEST_ARGS}
DEPENDS ${test_targets}
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
Expand Down
21 changes: 19 additions & 2 deletions tests/test_methods_and_attributes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import gc
import sys

import pytest
Expand All @@ -19,6 +20,12 @@
)


def gc_collect(repeat=5):
"""Collect garbage multiple times to ensure that objects are actually deleted."""
for _ in range(repeat):
gc.collect()


def test_self_only_pos_only():
assert (
m.ExampleMandA.__str__.__doc__
Expand Down Expand Up @@ -309,7 +316,7 @@ def test_property_rvalue_policy():
sys.version_info == (3, 14, 0, "beta", 1)
or sys.version_info == (3, 14, 0, "beta", 2),
reason="3.14.0b1/2 bug: https://github.com/python/cpython/issues/133912",
strict=True,
strict=False,
)
def test_dynamic_attributes():
instance = m.DynamicClass()
Expand Down Expand Up @@ -337,24 +344,32 @@ def test_dynamic_attributes():
cstats = ConstructorStats.get(m.DynamicClass)
assert cstats.alive() == 1
del instance
gc_collect(repeat=10)
assert cstats.alive() == 0

# Derived classes should work as well
class PythonDerivedDynamicClass(m.DynamicClass):
pass

for cls in m.CppDerivedDynamicClass, PythonDerivedDynamicClass:
for cls in (m.CppDerivedDynamicClass, PythonDerivedDynamicClass):
derived = cls()
derived.foobar = 100
assert derived.foobar == 100

assert cstats.alive() == 1
del derived
gc_collect(repeat=10)
assert cstats.alive() == 0


# https://foss.heptapod.net/pypy/pypy/-/issues/2447
@pytest.mark.xfail("env.PYPY")
@pytest.mark.xfail(
sys.version_info == (3, 14, 0, "beta", 1)
or sys.version_info == (3, 14, 0, "beta", 2),
reason="3.14.0b1/2 bug: https://github.com/python/cpython/issues/133912",
strict=False,
)
@pytest.mark.skipif("env.GRAALPY", reason="Cannot reliably trigger GC")
def test_cyclic_gc():
# One object references itself
Expand All @@ -364,6 +379,7 @@ def test_cyclic_gc():
cstats = ConstructorStats.get(m.DynamicClass)
assert cstats.alive() == 1
del instance
gc_collect(repeat=10)
assert cstats.alive() == 0

# Two object reference each other
Expand All @@ -374,6 +390,7 @@ def test_cyclic_gc():

assert cstats.alive() == 2
del i1, i2
gc_collect(repeat=10)
assert cstats.alive() == 0


Expand Down
2 changes: 1 addition & 1 deletion tests/test_pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_roundtrip(cls_name):
sys.version_info == (3, 14, 0, "beta", 1)
or sys.version_info == (3, 14, 0, "beta", 2),
reason="3.14.0b1/2 bug: https://github.com/python/cpython/issues/133912",
strict=True,
strict=False,
),
),
"PickleableWithDictNew",
Expand Down
206 changes: 206 additions & 0 deletions tests/test_scoped_critical_section.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
#include <pybind11/critical_section.h>

#include "pybind11_tests.h"

#include <atomic>
#include <chrono>
#include <thread>
#include <utility>

#if defined(PYBIND11_CPP20) && defined(__has_include) && __has_include(<barrier>)
# define PYBIND11_HAS_BARRIER 1
# include <barrier>
#endif

namespace test_scoped_critical_section_ns {

void test_one_nullptr() { py::scoped_critical_section lock{py::handle{}}; }

void test_two_nullptrs() { py::scoped_critical_section lock{py::handle{}, py::handle{}}; }

void test_first_nullptr() {
py::dict d;
py::scoped_critical_section lock{py::handle{}, d};
}

void test_second_nullptr() {
py::dict d;
py::scoped_critical_section lock{d, py::handle{}};
}
Comment on lines +15 to +29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I was suggesting to add tests, all I had in mind was something simple like this.

The additional tests look interesting though.

Could you please add comments to explain what the tests are for? (I'd ask my favorite LLM for suggestions, usually it's really quick that way.)


// Referenced test implementation: https://github.com/PyO3/pyo3/blob/v0.25.0/src/sync.rs#L874
class BoolWrapper {
public:
explicit BoolWrapper(bool value) : value_{value} {}
bool get() const { return value_.load(std::memory_order_acquire); }
void set(bool value) { value_.store(value, std::memory_order_release); }

private:
std::atomic<bool> value_{false};
};

#if defined(PYBIND11_HAS_BARRIER)

void test_scoped_critical_section(const py::handle &cls) {
auto barrier = std::barrier(2);
auto bool_wrapper = cls(false);
bool output = false;

{
py::gil_scoped_release gil_release{};

std::thread t1([&]() {
py::gil_scoped_acquire ensure_tstate{};
py::scoped_critical_section lock{bool_wrapper};
barrier.arrive_and_wait();
auto *bw = bool_wrapper.cast<BoolWrapper *>();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
bw->set(true);
});

std::thread t2([&]() {
barrier.arrive_and_wait();
{
py::gil_scoped_acquire ensure_tstate{};
py::scoped_critical_section lock{bool_wrapper};
auto *bw = bool_wrapper.cast<BoolWrapper *>();
output = bw->get();
}
});

t1.join();
t2.join();
}

if (!output) {
throw std::runtime_error("Scoped critical section test failed: output is false");
}
}

void test_scoped_critical_section2(const py::handle &cls) {
auto barrier = std::barrier(3);
auto bool_wrapper1 = cls(false);
auto bool_wrapper2 = cls(false);
std::pair<bool, bool> output{false, false};

{
py::gil_scoped_release gil_release{};

std::thread t1([&]() {
py::gil_scoped_acquire ensure_tstate{};
py::scoped_critical_section lock{bool_wrapper1, bool_wrapper2};
barrier.arrive_and_wait();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
auto *bw1 = bool_wrapper1.cast<BoolWrapper *>();
auto *bw2 = bool_wrapper2.cast<BoolWrapper *>();
bw1->set(true);
bw2->set(true);
});

std::thread t2([&]() {
barrier.arrive_and_wait();
{
py::gil_scoped_acquire ensure_tstate{};
py::scoped_critical_section lock{bool_wrapper1};
auto *bw1 = bool_wrapper1.cast<BoolWrapper *>();
output.first = bw1->get();
}
});

std::thread t3([&]() {
barrier.arrive_and_wait();
{
py::gil_scoped_acquire ensure_tstate{};
py::scoped_critical_section lock{bool_wrapper2};
auto *bw2 = bool_wrapper2.cast<BoolWrapper *>();
output.second = bw2->get();
}
});

t1.join();
t2.join();
t3.join();
}

if (!output.first || !output.second) {
throw std::runtime_error(
"Scoped critical section test with two objects failed: output is false");
}
}

void test_scoped_critical_section2_same_object_no_deadlock(const py::handle &cls) {
auto barrier = std::barrier(2);
auto bool_wrapper = cls(false);
bool output = false;

{
py::gil_scoped_release gil_release{};

std::thread t1([&]() {
py::gil_scoped_acquire ensure_tstate{};
py::scoped_critical_section lock{bool_wrapper, bool_wrapper};
barrier.arrive_and_wait();
std::this_thread::sleep_for(std::chrono::milliseconds(10));
auto *bw = bool_wrapper.cast<BoolWrapper *>();
bw->set(true);
});

std::thread t2([&]() {
barrier.arrive_and_wait();
{
py::gil_scoped_acquire ensure_tstate{};
py::scoped_critical_section lock{bool_wrapper};
auto *bw = bool_wrapper.cast<BoolWrapper *>();
output = bw->get();
}
});

t1.join();
t2.join();
}

if (!output) {
throw std::runtime_error(
"Scoped critical section test with same object failed: output is false");
}
}

#else

void test_scoped_critical_section(const py::handle &) {}
void test_scoped_critical_section2(const py::handle &) {}
void test_scoped_critical_section2_same_object_no_deadlock(const py::handle &) {}

#endif

} // namespace test_scoped_critical_section_ns

TEST_SUBMODULE(scoped_critical_section, m) {
using namespace test_scoped_critical_section_ns;

m.def("test_one_nullptr", test_one_nullptr);
m.def("test_two_nullptrs", test_two_nullptrs);
m.def("test_first_nullptr", test_first_nullptr);
m.def("test_second_nullptr", test_second_nullptr);

auto BoolWrapperClass = py::class_<BoolWrapper>(m, "BoolWrapper")
.def(py::init<bool>())
.def("get", &BoolWrapper::get)
.def("set", &BoolWrapper::set);
auto BoolWrapperHandle = py::handle(BoolWrapperClass);
(void) BoolWrapperHandle.ptr(); // suppress unused variable warning

#ifdef PYBIND11_HAS_BARRIER
m.attr("has_barrier") = true;

m.def("test_scoped_critical_section",
[BoolWrapperHandle]() -> void { test_scoped_critical_section(BoolWrapperHandle); });
m.def("test_scoped_critical_section2",
[BoolWrapperHandle]() -> void { test_scoped_critical_section2(BoolWrapperHandle); });
m.def("test_scoped_critical_section2_same_object_no_deadlock", [BoolWrapperHandle]() -> void {
test_scoped_critical_section2_same_object_no_deadlock(BoolWrapperHandle);
});
#else
m.attr("has_barrier") = false;
#endif
}
30 changes: 30 additions & 0 deletions tests/test_scoped_critical_section.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

import pytest

from pybind11_tests import scoped_critical_section as m


def test_nullptr_combinations():
m.test_one_nullptr()
m.test_two_nullptrs()
m.test_first_nullptr()
m.test_second_nullptr()


@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
def test_scoped_critical_section() -> None:
for _ in range(64):
m.test_scoped_critical_section()


@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
def test_scoped_critical_section2() -> None:
for _ in range(64):
m.test_scoped_critical_section2()


@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
def test_scoped_critical_section2_same_object_no_deadlock() -> None:
for _ in range(64):
m.test_scoped_critical_section2_same_object_no_deadlock()
Loading