Skip to content

Commit 1e6b4a4

Browse files
ShaneHarveyBen Warner
andauthored
PYTHON-3406 Log traceback when fork() test encounters a deadlock (#1045)
Co-authored-by: Ben Warner <[email protected]>
1 parent 7f19186 commit 1e6b4a4

File tree

3 files changed

+78
-52
lines changed

3 files changed

+78
-52
lines changed

test/__init__.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
import base64
1919
import gc
20+
import multiprocessing
2021
import os
22+
import signal
2123
import socket
2224
import sys
2325
import threading
@@ -43,7 +45,7 @@
4345
from contextlib import contextmanager
4446
from functools import wraps
4547
from test.version import Version
46-
from typing import Dict, Generator, no_type_check
48+
from typing import Callable, Dict, Generator, no_type_check
4749
from unittest import SkipTest
4850
from urllib.parse import quote_plus
4951

@@ -999,31 +1001,33 @@ def fail_point(self, command_args):
9991001
)
10001002

10011003
@contextmanager
1002-
def fork(self) -> Generator[int, None, None]:
1004+
def fork(
1005+
self, target: Callable, timeout: float = 60
1006+
) -> Generator[multiprocessing.Process, None, None]:
10031007
"""Helper for tests that use os.fork()
10041008
10051009
Use in a with statement:
10061010
1007-
with self.fork() as pid:
1008-
if pid == 0: # Child
1009-
pass
1010-
else: # Parent
1011-
pass
1011+
with self.fork(target=lambda: print('in child')) as proc:
1012+
self.assertTrue(proc.pid) # Child process was started
10121013
"""
1013-
pid = os.fork()
1014-
in_child = pid == 0
1014+
ctx = multiprocessing.get_context("fork")
1015+
proc = ctx.Process(target=target)
1016+
proc.start()
10151017
try:
1016-
yield pid
1017-
except:
1018-
if in_child:
1019-
traceback.print_exc()
1020-
os._exit(1)
1021-
raise
1018+
yield proc # type: ignore
10221019
finally:
1023-
if in_child:
1024-
os._exit(0)
1025-
# In parent, assert child succeeded.
1026-
self.assertEqual(0, os.waitpid(pid, 0)[1])
1020+
proc.join(timeout)
1021+
pid = proc.pid
1022+
assert pid
1023+
if proc.exitcode is None:
1024+
# If it failed, SIGINT to get traceback and wait 10s.
1025+
os.kill(pid, signal.SIGINT)
1026+
proc.join(10)
1027+
proc.kill()
1028+
proc.join(1)
1029+
self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?")
1030+
self.assertEqual(proc.exitcode, 0)
10271031

10281032

10291033
class IntegrationTest(PyMongoTestCase):

test/test_encryption.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,9 +341,13 @@ def test_use_after_close(self):
341341
def test_fork(self):
342342
opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys")
343343
client = rs_or_single_client(auto_encryption_opts=opts)
344-
with self.fork():
344+
self.addCleanup(client.close)
345+
346+
def target():
345347
client.admin.command("ping")
346-
client.close()
348+
349+
with self.fork(target):
350+
target()
347351

348352

349353
class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest):

test/test_fork.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,26 @@
1515
"""Test that pymongo is fork safe."""
1616

1717
import os
18+
import sys
19+
import unittest
1820
from multiprocessing import Pipe
21+
22+
from bson.objectid import ObjectId
23+
24+
sys.path[0:0] = [""]
25+
1926
from test import IntegrationTest
2027
from test.utils import (
2128
ExceptionCatchingThread,
2229
is_greenthread_patched,
2330
rs_or_single_client,
2431
)
25-
from unittest import skipIf
2632

27-
from bson.objectid import ObjectId
2833

29-
30-
@skipIf(
34+
@unittest.skipIf(
3135
not hasattr(os, "register_at_fork"), "register_at_fork not available in this version of Python"
3236
)
33-
@skipIf(
37+
@unittest.skipIf(
3438
is_greenthread_patched(),
3539
"gevent and eventlet do not support POSIX-style forking.",
3640
)
@@ -40,20 +44,26 @@ def test_lock_client(self):
4044
# Parent => All locks should be as before the fork.
4145
# Child => All locks should be reset.
4246
with self.client._MongoClient__lock:
43-
with self.fork() as pid:
44-
if pid == 0: # Child
45-
self.client.admin.command("ping")
47+
48+
def target():
49+
self.client.admin.command("ping")
50+
51+
with self.fork(target):
52+
pass
4653
self.client.admin.command("ping")
4754

4855
def test_lock_object_id(self):
4956
# Forks the client with ObjectId's _inc_lock locked.
5057
# Parent => _inc_lock should remain locked.
5158
# Child => _inc_lock should be unlocked.
5259
with ObjectId._inc_lock:
53-
with self.fork() as pid:
54-
if pid == 0: # Child
55-
self.assertFalse(ObjectId._inc_lock.locked())
56-
self.assertTrue(ObjectId())
60+
61+
def target():
62+
self.assertFalse(ObjectId._inc_lock.locked())
63+
self.assertTrue(ObjectId())
64+
65+
with self.fork(target):
66+
pass
5767

5868
def test_topology_reset(self):
5969
# Tests that topologies are different from each other.
@@ -63,22 +73,23 @@ def test_topology_reset(self):
6373
parent_conn, child_conn = Pipe()
6474
init_id = self.client._topology._pid
6575
parent_cursor_exc = self.client._kill_cursors_executor
66-
with self.fork() as pid:
67-
if pid == 0: # Child
68-
self.client.admin.command("ping")
69-
child_conn.send(self.client._topology._pid)
70-
child_conn.send(
71-
(
72-
parent_cursor_exc != self.client._kill_cursors_executor,
73-
"client._kill_cursors_executor was not reinitialized",
74-
)
76+
77+
def target():
78+
self.client.admin.command("ping")
79+
child_conn.send(self.client._topology._pid)
80+
child_conn.send(
81+
(
82+
parent_cursor_exc != self.client._kill_cursors_executor,
83+
"client._kill_cursors_executor was not reinitialized",
7584
)
76-
else: # Parent
77-
self.assertEqual(self.client._topology._pid, init_id)
78-
child_id = parent_conn.recv()
79-
self.assertNotEqual(child_id, init_id)
80-
passed, msg = parent_conn.recv()
81-
self.assertTrue(passed, msg)
85+
)
86+
87+
with self.fork(target):
88+
self.assertEqual(self.client._topology._pid, init_id)
89+
child_id = parent_conn.recv()
90+
self.assertNotEqual(child_id, init_id)
91+
passed, msg = parent_conn.recv()
92+
self.assertTrue(passed, msg)
8293

8394
def test_many_threaded(self):
8495
# Fork randomly while doing operations.
@@ -106,10 +117,13 @@ def action(client):
106117
rc = self.clients[i % len(self.clients)]
107118
if i % 50 == 0 and self.fork:
108119
# Fork
109-
with self.runner.fork() as pid:
110-
if pid == 0: # Child
111-
for c in self.clients:
112-
action(c)
120+
def target():
121+
for c_ in self.clients:
122+
action(c_)
123+
c_.close()
124+
125+
with self.runner.fork(target=target) as proc:
126+
self.runner.assertTrue(proc.pid)
113127
action(rc)
114128

115129
threads = [ForkThread(self, clients) for _ in range(10)]
@@ -125,3 +139,7 @@ def action(client):
125139

126140
for c in clients:
127141
c.close()
142+
143+
144+
if __name__ == "__main__":
145+
unittest.main()

0 commit comments

Comments
 (0)