15
15
"""Execute Transactions Spec tests."""
16
16
from __future__ import annotations
17
17
18
+ import asyncio
18
19
import os
19
20
import sys
20
21
import time
23
24
sys .path [0 :0 ] = ["" ]
24
25
25
26
from test .asynchronous import AsyncIntegrationTest , client_knobs , unittest
26
- from test .asynchronous .utils_spec_runner import AsyncSpecTestCreator , SpecRunnerThread
27
- from test .pymongo_mocks import DummyMonitor
27
+ from test .asynchronous .pymongo_mocks import DummyMonitor
28
+ from test .asynchronous . utils_spec_runner import AsyncSpecTestCreator , SpecRunnerTask
28
29
from test .utils import (
29
30
CMAPListener ,
30
31
async_client_context ,
@@ -91,23 +92,23 @@ class AsyncTestCMAP(AsyncIntegrationTest):
91
92
92
93
# Test operations:
93
94
94
- def start (self , op ):
95
+ async def start (self , op ):
95
96
"""Run the 'start' thread operation."""
96
97
target = op ["target" ]
97
- thread = SpecRunnerThread (target )
98
- thread .start ()
98
+ thread = SpecRunnerTask (target )
99
+ await thread .start ()
99
100
self .targets [target ] = thread
100
101
101
- def wait (self , op ):
102
+ async def wait (self , op ):
102
103
"""Run the 'wait' operation."""
103
- time .sleep (op ["ms" ] / 1000.0 )
104
+ await asyncio .sleep (op ["ms" ] / 1000.0 )
104
105
105
- def wait_for_thread (self , op ):
106
+ async def wait_for_thread (self , op ):
106
107
"""Run the 'waitForThread' operation."""
107
108
target = op ["target" ]
108
109
thread = self .targets [target ]
109
- thread .stop ()
110
- thread .join ()
110
+ await thread .stop ()
111
+ await thread .join ()
111
112
if thread .exc :
112
113
raise thread .exc
113
114
self .assertFalse (thread .ops )
@@ -123,53 +124,53 @@ async def wait_for_event(self, op):
123
124
timeout = timeout ,
124
125
)
125
126
126
- def check_out (self , op ):
127
+ async def check_out (self , op ):
127
128
"""Run the 'checkOut' operation."""
128
129
label = op ["label" ]
129
- with self .pool .checkout () as conn :
130
+ async with self .pool .checkout () as conn :
130
131
# Call 'pin_cursor' so we can hold the socket.
131
132
conn .pin_cursor ()
132
133
if label :
133
134
self .labels [label ] = conn
134
135
else :
135
136
self .addAsyncCleanup (conn .close_conn , None )
136
137
137
- def check_in (self , op ):
138
+ async def check_in (self , op ):
138
139
"""Run the 'checkIn' operation."""
139
140
label = op ["connection" ]
140
141
conn = self .labels [label ]
141
- self .pool .checkin (conn )
142
+ await self .pool .checkin (conn )
142
143
143
- def ready (self , op ):
144
+ async def ready (self , op ):
144
145
"""Run the 'ready' operation."""
145
- self .pool .ready ()
146
+ await self .pool .ready ()
146
147
147
- def clear (self , op ):
148
+ async def clear (self , op ):
148
149
"""Run the 'clear' operation."""
149
- if "interruptInUseAsyncConnections " in op :
150
- self .pool .reset (interrupt_connections = op ["interruptInUseAsyncConnections " ])
150
+ if "interruptInUseConnections " in op :
151
+ await self .pool .reset (interrupt_connections = op ["interruptInUseConnections " ])
151
152
else :
152
- self .pool .reset ()
153
+ await self .pool .reset ()
153
154
154
155
async def close (self , op ):
155
- """Run the 'aclose ' operation."""
156
- await self .pool .aclose ()
156
+ """Run the 'close ' operation."""
157
+ await self .pool .close ()
157
158
158
- def run_operation (self , op ):
159
+ async def run_operation (self , op ):
159
160
"""Run a single operation in a test."""
160
161
op_name = camel_to_snake (op ["name" ])
161
162
thread = op ["thread" ]
162
163
meth = getattr (self , op_name )
163
164
if thread :
164
- self .targets [thread ].schedule (lambda : meth (op ))
165
+ await self .targets [thread ].schedule (lambda : meth (op ))
165
166
else :
166
- meth (op )
167
+ await meth (op )
167
168
168
- def run_operations (self , ops ):
169
+ async def run_operations (self , ops ):
169
170
"""Run a test's operations."""
170
171
for op in ops :
171
172
self ._ops .append (op )
172
- self .run_operation (op )
173
+ await self .run_operation (op )
173
174
174
175
def check_object (self , actual , expected ):
175
176
"""Assert that the actual object matches the expected object."""
@@ -215,10 +216,10 @@ async def _set_fail_point(self, client, command_args):
215
216
cmd .update (command_args )
216
217
await client .admin .command (cmd )
217
218
218
- def set_fail_point (self , command_args ):
219
+ async def set_fail_point (self , command_args ):
219
220
if not async_client_context .supports_failCommand_fail_point :
220
221
self .skipTest ("failCommand fail point must be supported" )
221
- self ._set_fail_point (self .client , command_args )
222
+ await self ._set_fail_point (self .client , command_args )
222
223
223
224
async def run_scenario (self , scenario_def , test ):
224
225
"""Run a CMAP spec test."""
@@ -231,7 +232,7 @@ async def run_scenario(self, scenario_def, test):
231
232
# Configure the fail point before creating the client.
232
233
if "failPoint" in test :
233
234
fp = test ["failPoint" ]
234
- self .set_fail_point (fp )
235
+ await self .set_fail_point (fp )
235
236
self .addAsyncCleanup (
236
237
self .set_fail_point , {"configureFailPoint" : fp ["configureFailPoint" ], "mode" : "off" }
237
238
)
@@ -254,16 +255,18 @@ async def run_scenario(self, scenario_def, test):
254
255
# PoolReadyEvents. Instead, update the initial state before
255
256
# opening the Topology.
256
257
td = async_client_context .client ._topology .description
257
- sd = td .server_descriptions ()[(async_client_context .host , async_client_context .port )]
258
+ sd = td .server_descriptions ()[
259
+ (await async_client_context .host , await async_client_context .port )
260
+ ]
258
261
client ._topology ._description = updated_topology_description (
259
262
client ._topology ._description , sd
260
263
)
261
264
# When backgroundThreadIntervalMS is negative we do not start the
262
265
# background thread to ensure it never runs.
263
266
if interval < 0 :
264
- client ._topology .open ()
267
+ await client ._topology .open ()
265
268
else :
266
- client ._get_topology ()
269
+ await client ._get_topology ()
267
270
self .pool = list (client ._topology ._servers .values ())[0 ].pool
268
271
269
272
# Map of target names to Thread objects.
@@ -273,21 +276,21 @@ async def run_scenario(self, scenario_def, test):
273
276
274
277
async def cleanup ():
275
278
for t in self .targets .values ():
276
- t .stop ()
279
+ await t .stop ()
277
280
for t in self .targets .values ():
278
- t .join (5 )
281
+ await t .join (5 )
279
282
for conn in self .labels .values ():
280
- await conn .aclose_conn (None )
283
+ conn .close_conn (None )
281
284
282
285
self .addAsyncCleanup (cleanup )
283
286
284
287
try :
285
288
if test ["error" ]:
286
289
with self .assertRaises (PyMongoError ) as ctx :
287
- self .run_operations (test ["operations" ])
290
+ await self .run_operations (test ["operations" ])
288
291
self .check_error (ctx .exception , test ["error" ])
289
292
else :
290
- self .run_operations (test ["operations" ])
293
+ await self .run_operations (test ["operations" ])
291
294
292
295
self .check_events (test ["events" ], test ["ignore" ])
293
296
except Exception :
@@ -452,8 +455,8 @@ async def test_close_leaves_pool_unpaused(self):
452
455
453
456
454
457
def create_test (scenario_def , test , name ):
455
- def run_scenario (self ):
456
- self .run_scenario (scenario_def , test )
458
+ async def run_scenario (self ):
459
+ await self .run_scenario (scenario_def , test )
457
460
458
461
return run_scenario
459
462
@@ -468,9 +471,8 @@ async def tests(self, scenario_def):
468
471
return [scenario_def ]
469
472
470
473
471
- if _IS_SYNC :
472
- test_creator = CMAPSpecTestCreator (create_test , AsyncTestCMAP , AsyncTestCMAP .TEST_PATH )
473
- test_creator .create_tests ()
474
+ test_creator = CMAPSpecTestCreator (create_test , AsyncTestCMAP , AsyncTestCMAP .TEST_PATH )
475
+ test_creator .create_tests ()
474
476
475
477
476
478
if __name__ == "__main__" :
0 commit comments