Skip to content

Commit 382a94c

Browse files
committed
Add other retryable reads to test
1 parent 979208b commit 382a94c

File tree

2 files changed

+56
-40
lines changed

2 files changed

+56
-40
lines changed

test/asynchronous/test_retryable_reads.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -220,35 +220,43 @@ async def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are
220220

221221
@async_client_context.require_failCommand_fail_point
222222
async def test_retryable_reads_are_retried_on_the_same_implicit_session(self):
223-
fail_command = {
224-
"configureFailPoint": "failCommand",
225-
"mode": {"times": 1},
226-
"data": {"failCommands": ["count"], "errorCode": 6},
227-
}
228-
229223
listener = OvertCommandListener()
230224
client = await self.async_rs_or_single_client(
231225
directConnection=False,
232226
event_listeners=[listener],
233227
retryReads=True,
234228
)
235229

236-
await async_set_fail_point(client, fail_command)
237-
238-
await client.t.t.estimated_document_count()
230+
commands = [
231+
("aggregate", lambda: client.t.t.count_documents({})),
232+
("aggregate", lambda: client.t.t.aggregate([{"$match": {}}])),
233+
("count", lambda: client.t.t.estimated_document_count()),
234+
("distinct", lambda: client.t.t.distinct("x")),
235+
("find", lambda: client.t.t.find_one({})),
236+
("listDatabases", lambda: client.list_databases()),
237+
("listCollections", lambda: client.t.list_collections()),
238+
("listIndexes", lambda: client.t.t.list_indexes()),
239+
]
239240

240-
# Disable failpoint.
241-
fail_command["mode"] = "off"
242-
await async_set_fail_point(client, fail_command)
241+
for command_name, operation in commands:
242+
listener.reset()
243+
fail_command = {
244+
"configureFailPoint": "failCommand",
245+
"mode": {"times": 1},
246+
"data": {"failCommands": [command_name], "errorCode": 6},
247+
}
243248

244-
# Assert that both events occurred on the same session.
245-
lsids = [
246-
event.command["lsid"]
247-
for event in listener.started_events
248-
if event.command_name == "count"
249-
]
250-
self.assertEqual(len(lsids), 2)
251-
self.assertEqual(lsids[0], lsids[1])
249+
async with self.fail_point(fail_command):
250+
await operation()
251+
252+
# Assert that both events occurred on the same session.
253+
lsids = [
254+
event.command["lsid"]
255+
for event in listener.started_events
256+
if event.command_name == command_name
257+
]
258+
self.assertEqual(len(lsids), 2)
259+
self.assertEqual(lsids[0], lsids[1])
252260

253261

254262
if __name__ == "__main__":

test/test_retryable_reads.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -218,35 +218,43 @@ def test_retryable_reads_are_retried_on_the_same_mongos_when_no_others_are_avail
218218

219219
@client_context.require_failCommand_fail_point
220220
def test_retryable_reads_are_retried_on_the_same_implicit_session(self):
221-
fail_command = {
222-
"configureFailPoint": "failCommand",
223-
"mode": {"times": 1},
224-
"data": {"failCommands": ["count"], "errorCode": 6},
225-
}
226-
227221
listener = OvertCommandListener()
228222
client = self.rs_or_single_client(
229223
directConnection=False,
230224
event_listeners=[listener],
231225
retryReads=True,
232226
)
233227

234-
set_fail_point(client, fail_command)
235-
236-
client.t.t.estimated_document_count()
228+
commands = [
229+
("aggregate", lambda: client.t.t.count_documents({})),
230+
("aggregate", lambda: client.t.t.aggregate([{"$match": {}}])),
231+
("count", lambda: client.t.t.estimated_document_count()),
232+
("distinct", lambda: client.t.t.distinct("x")),
233+
("find", lambda: client.t.t.find_one({})),
234+
("listDatabases", lambda: client.list_databases()),
235+
("listCollections", lambda: client.t.list_collections()),
236+
("listIndexes", lambda: client.t.t.list_indexes()),
237+
]
237238

238-
# Disable failpoint.
239-
fail_command["mode"] = "off"
240-
set_fail_point(client, fail_command)
239+
for command_name, operation in commands:
240+
listener.reset()
241+
fail_command = {
242+
"configureFailPoint": "failCommand",
243+
"mode": {"times": 1},
244+
"data": {"failCommands": [command_name], "errorCode": 6},
245+
}
241246

242-
# Assert that both events occurred on the same session.
243-
lsids = [
244-
event.command["lsid"]
245-
for event in listener.started_events
246-
if event.command_name == "count"
247-
]
248-
self.assertEqual(len(lsids), 2)
249-
self.assertEqual(lsids[0], lsids[1])
247+
with self.fail_point(fail_command):
248+
operation()
249+
250+
# Assert that both events occurred on the same session.
251+
lsids = [
252+
event.command["lsid"]
253+
for event in listener.started_events
254+
if event.command_name == command_name
255+
]
256+
self.assertEqual(len(lsids), 2)
257+
self.assertEqual(lsids[0], lsids[1])
250258

251259

252260
if __name__ == "__main__":

0 commit comments

Comments
 (0)