@@ -220,35 +220,43 @@ async def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are
220
220
221
221
@async_client_context .require_failCommand_fail_point
222
222
async def test_retryable_reads_are_retried_on_the_same_implicit_session (self ):
223
- fail_command = {
224
- "configureFailPoint" : "failCommand" ,
225
- "mode" : {"times" : 1 },
226
- "data" : {"failCommands" : ["count" ], "errorCode" : 6 },
227
- }
228
-
229
223
listener = OvertCommandListener ()
230
224
client = await self .async_rs_or_single_client (
231
225
directConnection = False ,
232
226
event_listeners = [listener ],
233
227
retryReads = True ,
234
228
)
235
229
236
- await async_set_fail_point (client , fail_command )
237
-
238
- await client .t .t .estimated_document_count ()
230
+ commands = [
231
+ ("aggregate" , lambda : client .t .t .count_documents ({})),
232
+ ("aggregate" , lambda : client .t .t .aggregate ([{"$match" : {}}])),
233
+ ("count" , lambda : client .t .t .estimated_document_count ()),
234
+ ("distinct" , lambda : client .t .t .distinct ("x" )),
235
+ ("find" , lambda : client .t .t .find_one ({})),
236
+ ("listDatabases" , lambda : client .list_databases ()),
237
+ ("listCollections" , lambda : client .t .list_collections ()),
238
+ ("listIndexes" , lambda : client .t .t .list_indexes ()),
239
+ ]
239
240
240
- # Disable failpoint.
241
- fail_command ["mode" ] = "off"
242
- await async_set_fail_point (client , fail_command )
241
+ for command_name , operation in commands :
242
+ listener .reset ()
243
+ fail_command = {
244
+ "configureFailPoint" : "failCommand" ,
245
+ "mode" : {"times" : 1 },
246
+ "data" : {"failCommands" : [command_name ], "errorCode" : 6 },
247
+ }
243
248
244
- # Assert that both events occurred on the same session.
245
- lsids = [
246
- event .command ["lsid" ]
247
- for event in listener .started_events
248
- if event .command_name == "count"
249
- ]
250
- self .assertEqual (len (lsids ), 2 )
251
- self .assertEqual (lsids [0 ], lsids [1 ])
249
+ async with self .fail_point (fail_command ):
250
+ await operation ()
251
+
252
+ # Assert that both events occurred on the same session.
253
+ lsids = [
254
+ event .command ["lsid" ]
255
+ for event in listener .started_events
256
+ if event .command_name == command_name
257
+ ]
258
+ self .assertEqual (len (lsids ), 2 )
259
+ self .assertEqual (lsids [0 ], lsids [1 ])
252
260
253
261
254
262
if __name__ == "__main__" :
0 commit comments