Skip to content

Commit 14f6f4d

Browse files
committed
Fix refleaks in the test.
I can't add new testcases to test_multiprocessing_forkserver itself, i had to put them within an existing _test_multiprocessing test class. I don't know why, but refleaks are fragile and that test suite is... rediculiously complicated with all that it does.
1 parent c83193d commit 14f6f4d

File tree

3 files changed

+65
-53
lines changed

3 files changed

+65
-53
lines changed

Lib/multiprocessing/forkserver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def connect_to_new_process(self, fds):
106106
wrapped_client, self._forkserver_authkey)
107107
finally:
108108
wrapped_client._detach()
109+
del wrapped_client
109110
reduction.sendfds(client, allfds)
110111
return parent_r, parent_w
111112
except:

Lib/test/_test_multiprocessing.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -777,8 +777,8 @@ def test_error_on_stdio_flush_2(self):
777777
finally:
778778
setattr(sys, stream_name, old_stream)
779779

780-
@classmethod
781-
def _sleep_and_set_event(self, evt, delay=0.0):
780+
@staticmethod
781+
def _sleep_and_set_event(evt, delay=0.0):
782782
time.sleep(delay)
783783
evt.set()
784784

@@ -829,6 +829,68 @@ def test_forkserver_sigkill(self):
829829
if os.name != 'nt':
830830
self.check_forkserver_death(signal.SIGKILL)
831831

832+
@staticmethod
833+
def _exit_process():
834+
sys.exit(0)
835+
836+
def test_forkserver_auth_is_enabled(self):
837+
if self.TYPE == "threads":
838+
self.skipTest(f"test not appropriate for {self.TYPE}")
839+
if multiprocessing.get_start_method() != "forkserver":
840+
self.skipTest("forkserver start method specific")
841+
842+
forkserver = multiprocessing.forkserver._forkserver
843+
forkserver.ensure_running()
844+
self.assertTrue(forkserver._forkserver_pid)
845+
authkey = forkserver._forkserver_authkey
846+
self.assertTrue(authkey)
847+
self.assertGreater(len(authkey), 15)
848+
addr = forkserver._forkserver_address
849+
self.assertTrue(addr)
850+
851+
# First, demonstrate that a raw auth handshake as Client makes
852+
# does not raise an error.
853+
client = multiprocessing.connection.Client(addr, authkey=authkey)
854+
client.close()
855+
856+
# That worked, now launch a quick process.
857+
proc = self.Process(target=self._exit_process)
858+
proc.start()
859+
proc.join()
860+
self.assertEqual(proc.exitcode, 0)
861+
862+
def test_forkserver_without_auth_fails(self):
863+
if self.TYPE == "threads":
864+
self.skipTest(f"test not appropriate for {self.TYPE}")
865+
if multiprocessing.get_start_method() != "forkserver":
866+
self.skipTest("forkserver start method specific")
867+
868+
forkserver = multiprocessing.forkserver._forkserver
869+
forkserver.ensure_running()
870+
self.assertTrue(forkserver._forkserver_pid)
871+
authkey_len = len(forkserver._forkserver_authkey)
872+
with unittest.mock.patch.object(
873+
forkserver, '_forkserver_authkey', None):
874+
# With no auth handshake, the connection this makes to the
875+
# forkserver will fail to do the file descriptor transfer
876+
# over the pipe as the forkserver is expecting auth.
877+
proc = self.Process(target=self._exit_process)
878+
with self.assertRaisesRegex(RuntimeError, 'not receive ack'):
879+
proc.start()
880+
del proc
881+
882+
# With an incorrect authkey we should get an auth rejection
883+
# rather than the above protocol error.
884+
forkserver._forkserver_authkey = b'T'*authkey_len
885+
proc = self.Process(target=self._exit_process)
886+
with self.assertRaises(multiprocessing.AuthenticationError):
887+
proc.start()
888+
del proc
889+
890+
# authkey restored, launching processes should work again.
891+
proc = self.Process(target=self._exit_process)
892+
proc.start()
893+
proc.join()
832894

833895
#
834896
#
Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import unittest
2-
from unittest import mock
32
import test._test_multiprocessing
43

5-
import os
64
import sys
75
from test import support
86

@@ -12,56 +10,7 @@
1210
if sys.platform == "win32":
1311
raise unittest.SkipTest("forkserver is not available on Windows")
1412

15-
import multiprocessing
16-
import multiprocessing.connection
17-
import multiprocessing.forkserver
18-
1913
test._test_multiprocessing.install_tests_in_module_dict(globals(), 'forkserver')
2014

21-
22-
class TestForkserverControlAuthentication(unittest.TestCase):
23-
def setUp(self):
24-
super().setUp()
25-
self.context = multiprocessing.get_context("forkserver")
26-
self.pool = self.context.Pool(processes=1, maxtasksperchild=4)
27-
self.assertEqual(self.pool.apply(eval, ("2+2",)), 4)
28-
self.forkserver = multiprocessing.forkserver._forkserver
29-
self.addr = self.forkserver._forkserver_address
30-
self.assertTrue(self.addr)
31-
self.authkey = self.forkserver._forkserver_authkey
32-
self.assertGreater(len(self.authkey), 15)
33-
self.assertTrue(self.forkserver._forkserver_pid)
34-
35-
def tearDown(self):
36-
self.pool.terminate()
37-
self.pool.join()
38-
super().tearDown()
39-
40-
def test_auth_works(self):
41-
"""FYI: An 'EOFError: ran out of input' from a worker is normal."""
42-
# First, demonstrate that a raw auth handshake as Client makes
43-
# does not raise.
44-
client = multiprocessing.connection.Client(
45-
self.addr, authkey=self.authkey)
46-
client.close()
47-
48-
# Now use forkserver code to do the same thing and more.
49-
status_r, data_w = self.forkserver.connect_to_new_process([])
50-
# It is normal for this to trigger an EOFError on stderr from the
51-
# process... it is expecting us to send over a pickle of a Process
52-
# instance to tell it what to do.
53-
# If the authentication handshake and subsequent file descriptor
54-
# sending dance had failed, an exception would've been raised.
55-
os.close(data_w)
56-
os.close(status_r)
57-
58-
def test_no_auth_fails(self):
59-
with mock.patch.object(self.forkserver, '_forkserver_authkey', None):
60-
# With no authkey set, the connection this makes will fail to
61-
# do the file descriptor transfer over the pipe.
62-
with self.assertRaisesRegex(RuntimeError, 'not receive ack'):
63-
status_r, data_w = self.forkserver.connect_to_new_process([])
64-
65-
6615
if __name__ == '__main__':
6716
unittest.main()

0 commit comments

Comments
 (0)