Skip to content

Commit 2ce1f8a

Browse files
committed
Simplify changes in pybind11/critical_section.h and add test_nullptr_combinations()
1 parent 58f155c commit 2ce1f8a

File tree

3 files changed

+45
-22
lines changed

3 files changed

+45
-22
lines changed

include/pybind11/critical_section.h

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,33 @@ PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
1313
class scoped_critical_section {
1414
public:
1515
#ifdef Py_GIL_DISABLED
16-
scoped_critical_section(handle obj1, handle obj2) : m_ptr1(obj1.ptr()), m_ptr2(obj2.ptr()) {
17-
if (m_ptr1 == nullptr) {
18-
std::swap(m_ptr1, m_ptr2);
19-
}
20-
if (m_ptr2 != nullptr) {
21-
PyCriticalSection2_Begin(&section2, m_ptr1, m_ptr2);
22-
} else if (m_ptr1 != nullptr) {
23-
PyCriticalSection_Begin(&section, m_ptr1);
16+
scoped_critical_section(handle obj1, handle obj2) {
17+
if (obj1) {
18+
if (obj2) {
19+
PyCriticalSection2_Begin(&section2, obj1.ptr(), obj2.ptr());
20+
rank = 2;
21+
} else {
22+
PyCriticalSection_Begin(&section, obj1.ptr());
23+
rank = 1;
24+
}
25+
} else if (obj2) {
26+
PyCriticalSection_Begin(&section, obj2.ptr());
27+
rank = 1;
2428
}
2529
}
2630

27-
explicit scoped_critical_section(handle obj) : m_ptr1(obj.ptr()) {
28-
if (m_ptr1 != nullptr) {
29-
PyCriticalSection_Begin(&section, m_ptr1);
31+
explicit scoped_critical_section(handle obj) {
32+
if (obj) {
33+
PyCriticalSection_Begin(&section, obj.ptr());
34+
rank = 1;
3035
}
3136
}
3237

3338
~scoped_critical_section() {
34-
if (m_ptr2 != nullptr) {
35-
PyCriticalSection2_End(&section2);
36-
} else if (m_ptr1 != nullptr) {
39+
if (rank == 1) {
3740
PyCriticalSection_End(&section);
41+
} else if (rank == 2) {
42+
PyCriticalSection2_End(&section2);
3843
}
3944
}
4045
#else
@@ -48,8 +53,7 @@ class scoped_critical_section {
4853

4954
private:
5055
#ifdef Py_GIL_DISABLED
51-
PyObject *m_ptr1{nullptr};
52-
PyObject *m_ptr2{nullptr};
56+
int rank{0};
5357
union {
5458
PyCriticalSection section;
5559
PyCriticalSection2 section2;

tests/test_scoped_critical_section.cpp

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,20 @@
1414

1515
namespace test_scoped_critical_section_ns {
1616

17+
void test_one_nullptr() { py::scoped_critical_section lock{py::handle{}}; }
18+
19+
void test_two_nullptrs() { py::scoped_critical_section lock{py::handle{}, py::handle{}}; }
20+
21+
void test_first_nullptr() {
22+
py::dict d;
23+
py::scoped_critical_section lock{py::handle{}, d};
24+
}
25+
26+
void test_second_nullptr() {
27+
py::dict d;
28+
py::scoped_critical_section lock{d, py::handle{}};
29+
}
30+
1731
// Referenced test implementation: https://github.com/PyO3/pyo3/blob/v0.25.0/src/sync.rs#L874
1832
class BoolWrapper {
1933
public:
@@ -164,12 +178,10 @@ void test_scoped_critical_section2_same_object_no_deadlock(const py::handle &) {
164178
TEST_SUBMODULE(scoped_critical_section, m) {
165179
using namespace test_scoped_critical_section_ns;
166180

167-
m.attr("defined_THREAD_SANITIZER") =
168-
#if defined(THREAD_SANITIZER)
169-
true;
170-
#else
171-
false;
172-
#endif
181+
m.def("test_one_nullptr", test_one_nullptr);
182+
m.def("test_two_nullptrs", test_two_nullptrs);
183+
m.def("test_first_nullptr", test_first_nullptr);
184+
m.def("test_second_nullptr", test_second_nullptr);
173185

174186
auto BoolWrapperClass = py::class_<BoolWrapper>(m, "BoolWrapper")
175187
.def(py::init<bool>())

tests/test_scoped_critical_section.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
from pybind11_tests import scoped_critical_section as m
66

77

8+
def test_nullptr_combinations():
9+
m.test_one_nullptr()
10+
m.test_two_nullptrs()
11+
m.test_first_nullptr()
12+
m.test_second_nullptr()
13+
14+
815
@pytest.mark.skipif(not m.has_barrier, reason="no <barrier>")
916
def test_scoped_critical_section() -> None:
1017
for _ in range(64):

0 commit comments

Comments
 (0)