33#include " pybind11_tests.h"
44
55#include < atomic>
6- #include < cassert >
6+ #include < chrono >
77#include < thread>
8+ #include < utility>
89
9- #if defined(__has_include) && __has_include(<barrier>)
10+ #if defined(PYBIND11_CPP20) && defined( __has_include) && __has_include(<barrier>)
1011# define PYBIND11_HAS_BARRIER 1
1112# include < barrier>
1213#endif
@@ -19,88 +20,97 @@ class BoolWrapper {
1920 void set (bool value) { value_.store (value, std::memory_order_release); }
2021
2122private:
22- std::atomic<bool > value_;
23+ std::atomic<bool > value_{ false } ;
2324};
2425
2526#ifdef PYBIND11_HAS_BARRIER
26- void test_scoped_critical_section (py::class_<BoolWrapper> &cls) {
27+ bool test_scoped_critical_section (const py::handle &cls) {
2728 auto barrier = std::barrier (2 );
2829 auto bool_wrapper = cls (false );
30+ bool output = false ;
2931
3032 std::thread t1 ([&]() {
3133 py::scoped_critical_section lock{bool_wrapper};
3234 barrier.arrive_and_wait ();
33- auto bw = bool_wrapper.cast <std::shared_ptr< BoolWrapper> >();
35+ auto * bw = bool_wrapper.cast <BoolWrapper * >();
3436 std::this_thread::sleep_for (std::chrono::milliseconds (10 ));
3537 bw->set (true );
3638 });
3739
3840 std::thread t2 ([&]() {
3941 barrier.arrive_and_wait ();
4042 py::scoped_critical_section lock{bool_wrapper};
41- auto bw = bool_wrapper.cast <std::shared_ptr< BoolWrapper> >();
42- assert ( bw->get () == true );
43+ auto * bw = bool_wrapper.cast <BoolWrapper * >();
44+ output = bw->get ();
4345 });
4446
4547 t1.join ();
4648 t2.join ();
49+
50+ return output;
4751}
4852
49- void test_scoped_critical_section2 (py::class_<BoolWrapper> &cls) {
53+ std::pair< bool , bool > test_scoped_critical_section2 (const py::handle &cls) {
5054 auto barrier = std::barrier (3 );
5155 auto bool_wrapper1 = cls (false );
5256 auto bool_wrapper2 = cls (false );
57+ std::pair<bool , bool > output{false , false };
5358
5459 std::thread t1 ([&]() {
5560 py::scoped_critical_section lock{bool_wrapper1, bool_wrapper2};
5661 barrier.arrive_and_wait ();
5762 std::this_thread::sleep_for (std::chrono::milliseconds (10 ));
58- auto bw1 = bool_wrapper1.cast <std::shared_ptr< BoolWrapper> >();
59- auto bw2 = bool_wrapper2.cast <std::shared_ptr< BoolWrapper> >();
63+ auto * bw1 = bool_wrapper1.cast <BoolWrapper * >();
64+ auto * bw2 = bool_wrapper2.cast <BoolWrapper * >();
6065 bw1->set (true );
6166 bw2->set (true );
6267 });
6368
6469 std::thread t2 ([&]() {
6570 barrier.arrive_and_wait ();
6671 py::scoped_critical_section lock{bool_wrapper1};
67- auto bw1 = bool_wrapper1.cast <std::shared_ptr< BoolWrapper> >();
68- assert ( bw1->get () == true );
72+ auto * bw1 = bool_wrapper1.cast <BoolWrapper * >();
73+ output. first = bw1->get ();
6974 });
7075
7176 std::thread t3 ([&]() {
7277 barrier.arrive_and_wait ();
7378 py::scoped_critical_section lock{bool_wrapper2};
74- auto bw2 = bool_wrapper2.cast <std::shared_ptr< BoolWrapper> >();
75- assert ( bw2->get () == true );
79+ auto * bw2 = bool_wrapper2.cast <BoolWrapper * >();
80+ output. second = bw2->get ();
7681 });
7782
7883 t1.join ();
7984 t2.join ();
8085 t3.join ();
86+
87+ return output;
8188}
8289
83- void test_scoped_critical_section2_same_object_no_deadlock (py::class_<BoolWrapper> &cls) {
90+ bool test_scoped_critical_section2_same_object_no_deadlock (const py::handle &cls) {
8491 auto barrier = std::barrier (2 );
8592 auto bool_wrapper = cls (false );
93+ bool output = false ;
8694
8795 std::thread t1 ([&]() {
8896 py::scoped_critical_section lock{bool_wrapper, bool_wrapper};
8997 barrier.arrive_and_wait ();
9098 std::this_thread::sleep_for (std::chrono::milliseconds (10 ));
91- auto bw = bool_wrapper.cast <std::shared_ptr< BoolWrapper> >();
99+ auto * bw = bool_wrapper.cast <BoolWrapper * >();
92100 bw->set (true );
93101 });
94102
95103 std::thread t2 ([&]() {
96104 barrier.arrive_and_wait ();
97105 py::scoped_critical_section lock{bool_wrapper};
98- auto bw = bool_wrapper.cast <std::shared_ptr< BoolWrapper> >();
99- assert ( bw->get () == true );
106+ auto * bw = bool_wrapper.cast <BoolWrapper * >();
107+ output = bw->get ();
100108 });
101109
102110 t1.join ();
103111 t2.join ();
112+
113+ return output;
104114}
105115#endif
106116
@@ -116,16 +126,19 @@ TEST_SUBMODULE(scoped_critical_section, m) {
116126 .def (py::init<bool >())
117127 .def (" get" , &BoolWrapper::get)
118128 .def (" set" , &BoolWrapper::set);
129+ auto BoolWrapperHandle = py::handle (BoolWrapperClass);
119130
120131#ifdef PYBIND11_HAS_BARRIER
121132 m.attr (" has_barrier" ) = true ;
122133
123- m.def (" test_scoped_critical_section" ,
124- [&]() -> void { test_scoped_critical_section (BoolWrapperClass); });
125- m.def (" test_scoped_critical_section2" ,
126- [&]() -> void { test_scoped_critical_section2 (BoolWrapperClass); });
127- m.def (" test_scoped_critical_section2_same_object_no_deadlock" , [&]() -> void {
128- test_scoped_critical_section2_same_object_no_deadlock (BoolWrapperClass);
134+ m.def (" test_scoped_critical_section" , [BoolWrapperHandle]() -> bool {
135+ return test_scoped_critical_section (BoolWrapperHandle);
136+ });
137+ m.def (" test_scoped_critical_section2" , [BoolWrapperHandle]() -> std::pair<bool , bool > {
138+ return test_scoped_critical_section2 (BoolWrapperHandle);
139+ });
140+ m.def (" test_scoped_critical_section2_same_object_no_deadlock" , [BoolWrapperHandle]() -> bool {
141+ return test_scoped_critical_section2_same_object_no_deadlock (BoolWrapperHandle);
129142 });
130143#else
131144 m.attr (" has_barrier" ) = false ;
0 commit comments