15
15
"""Test that pymongo is fork safe."""
16
16
17
17
import os
18
+ import sys
19
+ import unittest
18
20
from multiprocessing import Pipe
21
+
22
+ from bson .objectid import ObjectId
23
+
24
+ sys .path [0 :0 ] = ["" ]
25
+
19
26
from test import IntegrationTest
20
27
from test .utils import (
21
28
ExceptionCatchingThread ,
22
29
is_greenthread_patched ,
23
30
rs_or_single_client ,
24
31
)
25
- from unittest import skipIf
26
32
27
- from bson .objectid import ObjectId
28
33
29
-
30
- @skipIf (
34
+ @unittest .skipIf (
31
35
not hasattr (os , "register_at_fork" ), "register_at_fork not available in this version of Python"
32
36
)
33
- @skipIf (
37
+ @unittest . skipIf (
34
38
is_greenthread_patched (),
35
39
"gevent and eventlet do not support POSIX-style forking." ,
36
40
)
@@ -40,20 +44,26 @@ def test_lock_client(self):
40
44
# Parent => All locks should be as before the fork.
41
45
# Child => All locks should be reset.
42
46
with self .client ._MongoClient__lock :
43
- with self .fork () as pid :
44
- if pid == 0 : # Child
45
- self .client .admin .command ("ping" )
47
+
48
+ def target ():
49
+ self .client .admin .command ("ping" )
50
+
51
+ with self .fork (target ):
52
+ pass
46
53
self .client .admin .command ("ping" )
47
54
48
55
def test_lock_object_id (self ):
49
56
# Forks the client with ObjectId's _inc_lock locked.
50
57
# Parent => _inc_lock should remain locked.
51
58
# Child => _inc_lock should be unlocked.
52
59
with ObjectId ._inc_lock :
53
- with self .fork () as pid :
54
- if pid == 0 : # Child
55
- self .assertFalse (ObjectId ._inc_lock .locked ())
56
- self .assertTrue (ObjectId ())
60
+
61
+ def target ():
62
+ self .assertFalse (ObjectId ._inc_lock .locked ())
63
+ self .assertTrue (ObjectId ())
64
+
65
+ with self .fork (target ):
66
+ pass
57
67
58
68
def test_topology_reset (self ):
59
69
# Tests that topologies are different from each other.
@@ -63,22 +73,23 @@ def test_topology_reset(self):
63
73
parent_conn , child_conn = Pipe ()
64
74
init_id = self .client ._topology ._pid
65
75
parent_cursor_exc = self .client ._kill_cursors_executor
66
- with self .fork () as pid :
67
- if pid == 0 : # Child
68
- self .client .admin .command ("ping" )
69
- child_conn .send (self .client ._topology ._pid )
70
- child_conn .send (
71
- (
72
- parent_cursor_exc != self .client ._kill_cursors_executor ,
73
- "client._kill_cursors_executor was not reinitialized" ,
74
- )
76
+
77
+ def target ():
78
+ self .client .admin .command ("ping" )
79
+ child_conn .send (self .client ._topology ._pid )
80
+ child_conn .send (
81
+ (
82
+ parent_cursor_exc != self .client ._kill_cursors_executor ,
83
+ "client._kill_cursors_executor was not reinitialized" ,
75
84
)
76
- else : # Parent
77
- self .assertEqual (self .client ._topology ._pid , init_id )
78
- child_id = parent_conn .recv ()
79
- self .assertNotEqual (child_id , init_id )
80
- passed , msg = parent_conn .recv ()
81
- self .assertTrue (passed , msg )
85
+ )
86
+
87
+ with self .fork (target ):
88
+ self .assertEqual (self .client ._topology ._pid , init_id )
89
+ child_id = parent_conn .recv ()
90
+ self .assertNotEqual (child_id , init_id )
91
+ passed , msg = parent_conn .recv ()
92
+ self .assertTrue (passed , msg )
82
93
83
94
def test_many_threaded (self ):
84
95
# Fork randomly while doing operations.
@@ -106,10 +117,13 @@ def action(client):
106
117
rc = self .clients [i % len (self .clients )]
107
118
if i % 50 == 0 and self .fork :
108
119
# Fork
109
- with self .runner .fork () as pid :
110
- if pid == 0 : # Child
111
- for c in self .clients :
112
- action (c )
120
+ def target ():
121
+ for c_ in self .clients :
122
+ action (c_ )
123
+ c_ .close ()
124
+
125
+ with self .runner .fork (target = target ) as proc :
126
+ self .runner .assertTrue (proc .pid )
113
127
action (rc )
114
128
115
129
threads = [ForkThread (self , clients ) for _ in range (10 )]
@@ -125,3 +139,7 @@ def action(client):
125
139
126
140
for c in clients :
127
141
c .close ()
142
+
143
+
144
+ if __name__ == "__main__" :
145
+ unittest .main ()
0 commit comments