Skip to content

Commit 37c3d7f

Browse files
XuehaiPanhenryiii
authored andcommitted
fix: fix segmentation fault in test
1 parent 09f3a0c commit 37c3d7f

File tree

2 files changed

+40
-27
lines changed

2 files changed

+40
-27
lines changed

tests/test_scoped_critical_section.cpp

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
#include "pybind11_tests.h"
44

55
#include <atomic>
6-
#include <cassert>
6+
#include <chrono>
77
#include <thread>
8+
#include <utility>
89

9-
#if defined(__has_include) && __has_include(<barrier>)
10+
#if defined(PYBIND11_CPP20) && defined(__has_include) && __has_include(<barrier>)
1011
# define PYBIND11_HAS_BARRIER 1
1112
# include <barrier>
1213
#endif
@@ -19,88 +20,97 @@ class BoolWrapper {
1920
void set(bool value) { value_.store(value, std::memory_order_release); }
2021

2122
private:
22-
std::atomic<bool> value_;
23+
std::atomic<bool> value_{false};
2324
};
2425

2526
#ifdef PYBIND11_HAS_BARRIER
26-
void test_scoped_critical_section(py::class_<BoolWrapper> &cls) {
27+
bool test_scoped_critical_section(const py::handle &cls) {
2728
auto barrier = std::barrier(2);
2829
auto bool_wrapper = cls(false);
30+
bool output = false;
2931

3032
std::thread t1([&]() {
3133
py::scoped_critical_section lock{bool_wrapper};
3234
barrier.arrive_and_wait();
33-
auto bw = bool_wrapper.cast<std::shared_ptr<BoolWrapper>>();
35+
auto *bw = bool_wrapper.cast<BoolWrapper *>();
3436
std::this_thread::sleep_for(std::chrono::milliseconds(10));
3537
bw->set(true);
3638
});
3739

3840
std::thread t2([&]() {
3941
barrier.arrive_and_wait();
4042
py::scoped_critical_section lock{bool_wrapper};
41-
auto bw = bool_wrapper.cast<std::shared_ptr<BoolWrapper>>();
42-
assert(bw->get() == true);
43+
auto *bw = bool_wrapper.cast<BoolWrapper *>();
44+
output = bw->get();
4345
});
4446

4547
t1.join();
4648
t2.join();
49+
50+
return output;
4751
}
4852

49-
void test_scoped_critical_section2(py::class_<BoolWrapper> &cls) {
53+
std::pair<bool, bool> test_scoped_critical_section2(const py::handle &cls) {
5054
auto barrier = std::barrier(3);
5155
auto bool_wrapper1 = cls(false);
5256
auto bool_wrapper2 = cls(false);
57+
std::pair<bool, bool> output{false, false};
5358

5459
std::thread t1([&]() {
5560
py::scoped_critical_section lock{bool_wrapper1, bool_wrapper2};
5661
barrier.arrive_and_wait();
5762
std::this_thread::sleep_for(std::chrono::milliseconds(10));
58-
auto bw1 = bool_wrapper1.cast<std::shared_ptr<BoolWrapper>>();
59-
auto bw2 = bool_wrapper2.cast<std::shared_ptr<BoolWrapper>>();
63+
auto *bw1 = bool_wrapper1.cast<BoolWrapper *>();
64+
auto *bw2 = bool_wrapper2.cast<BoolWrapper *>();
6065
bw1->set(true);
6166
bw2->set(true);
6267
});
6368

6469
std::thread t2([&]() {
6570
barrier.arrive_and_wait();
6671
py::scoped_critical_section lock{bool_wrapper1};
67-
auto bw1 = bool_wrapper1.cast<std::shared_ptr<BoolWrapper>>();
68-
assert(bw1->get() == true);
72+
auto *bw1 = bool_wrapper1.cast<BoolWrapper *>();
73+
output.first = bw1->get();
6974
});
7075

7176
std::thread t3([&]() {
7277
barrier.arrive_and_wait();
7378
py::scoped_critical_section lock{bool_wrapper2};
74-
auto bw2 = bool_wrapper2.cast<std::shared_ptr<BoolWrapper>>();
75-
assert(bw2->get() == true);
79+
auto *bw2 = bool_wrapper2.cast<BoolWrapper *>();
80+
output.second = bw2->get();
7681
});
7782

7883
t1.join();
7984
t2.join();
8085
t3.join();
86+
87+
return output;
8188
}
8289

83-
void test_scoped_critical_section2_same_object_no_deadlock(py::class_<BoolWrapper> &cls) {
90+
bool test_scoped_critical_section2_same_object_no_deadlock(const py::handle &cls) {
8491
auto barrier = std::barrier(2);
8592
auto bool_wrapper = cls(false);
93+
bool output = false;
8694

8795
std::thread t1([&]() {
8896
py::scoped_critical_section lock{bool_wrapper, bool_wrapper};
8997
barrier.arrive_and_wait();
9098
std::this_thread::sleep_for(std::chrono::milliseconds(10));
91-
auto bw = bool_wrapper.cast<std::shared_ptr<BoolWrapper>>();
99+
auto *bw = bool_wrapper.cast<BoolWrapper *>();
92100
bw->set(true);
93101
});
94102

95103
std::thread t2([&]() {
96104
barrier.arrive_and_wait();
97105
py::scoped_critical_section lock{bool_wrapper};
98-
auto bw = bool_wrapper.cast<std::shared_ptr<BoolWrapper>>();
99-
assert(bw->get() == true);
106+
auto *bw = bool_wrapper.cast<BoolWrapper *>();
107+
output = bw->get();
100108
});
101109

102110
t1.join();
103111
t2.join();
112+
113+
return output;
104114
}
105115
#endif
106116

@@ -116,16 +126,19 @@ TEST_SUBMODULE(scoped_critical_section, m) {
116126
.def(py::init<bool>())
117127
.def("get", &BoolWrapper::get)
118128
.def("set", &BoolWrapper::set);
129+
auto BoolWrapperHandle = py::handle(BoolWrapperClass);
119130

120131
#ifdef PYBIND11_HAS_BARRIER
121132
m.attr("has_barrier") = true;
122133

123-
m.def("test_scoped_critical_section",
124-
[&]() -> void { test_scoped_critical_section(BoolWrapperClass); });
125-
m.def("test_scoped_critical_section2",
126-
[&]() -> void { test_scoped_critical_section2(BoolWrapperClass); });
127-
m.def("test_scoped_critical_section2_same_object_no_deadlock", [&]() -> void {
128-
test_scoped_critical_section2_same_object_no_deadlock(BoolWrapperClass);
134+
m.def("test_scoped_critical_section", [BoolWrapperHandle]() -> bool {
135+
return test_scoped_critical_section(BoolWrapperHandle);
136+
});
137+
m.def("test_scoped_critical_section2", [BoolWrapperHandle]() -> std::pair<bool, bool> {
138+
return test_scoped_critical_section2(BoolWrapperHandle);
139+
});
140+
m.def("test_scoped_critical_section2_same_object_no_deadlock", [BoolWrapperHandle]() -> bool {
141+
return test_scoped_critical_section2_same_object_no_deadlock(BoolWrapperHandle);
129142
});
130143
#else
131144
m.attr("has_barrier") = false;

tests/test_scoped_critical_section.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88
@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
99
def test_scoped_critical_section() -> None:
1010
for _ in range(64):
11-
m.test_scoped_critical_section()
11+
assert m.test_scoped_critical_section() is True
1212

1313

1414
@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
1515
def test_scoped_critical_section2() -> None:
1616
for _ in range(64):
17-
assert m.test_scoped_critical_section2()
17+
assert m.test_scoped_critical_section2() == (True, True)
1818

1919

2020
@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
2121
def test_scoped_critical_section2_same_object_no_deadlock() -> None:
2222
for _ in range(64):
23-
m.test_scoped_critical_section2_same_object_no_deadlock()
23+
assert m.test_scoped_critical_section2_same_object_no_deadlock() is True

0 commit comments

Comments
 (0)