|
35 | 35 | from gridfs import GridFS, GridFSBucket
|
36 | 36 | from pymongo import WriteConcern, client_session
|
37 | 37 | from pymongo.client_session import TransactionOptions
|
| 38 | +from pymongo.command_cursor import CommandCursor |
| 39 | +from pymongo.cursor import Cursor |
38 | 40 | from pymongo.errors import (
|
39 | 41 | CollectionInvalid,
|
40 | 42 | ConfigurationError,
|
@@ -351,6 +353,42 @@ def test_transaction_starts_with_batched_write(self):
|
351 | 353 | self.assertEqual(txn_number, event.command["txnNumber"])
|
352 | 354 | self.assertEqual(48, coll.count_documents({}))
|
353 | 355 |
|
| 356 | + @client_context.require_transactions |
| 357 | + def test_transaction_direct_connection(self): |
| 358 | + client = single_client() |
| 359 | + self.addCleanup(client.close) |
| 360 | + coll = client.pymongo_test.test |
| 361 | + |
| 362 | + # Make sure the collection exists. |
| 363 | + coll.insert_one({}) |
| 364 | + self.assertEqual(client.topology_description.topology_type_name, "Single") |
| 365 | + ops = [ |
| 366 | + (coll.bulk_write, [[InsertOne({})]]), |
| 367 | + (coll.insert_one, [{}]), |
| 368 | + (coll.insert_many, [[{}, {}]]), |
| 369 | + (coll.replace_one, [{}, {}]), |
| 370 | + (coll.update_one, [{}, {"$set": {"a": 1}}]), |
| 371 | + (coll.update_many, [{}, {"$set": {"a": 1}}]), |
| 372 | + (coll.delete_one, [{}]), |
| 373 | + (coll.delete_many, [{}]), |
| 374 | + (coll.find_one_and_replace, [{}, {}]), |
| 375 | + (coll.find_one_and_update, [{}, {"$set": {"a": 1}}]), |
| 376 | + (coll.find_one_and_delete, [{}, {}]), |
| 377 | + (coll.find_one, [{}]), |
| 378 | + (coll.count_documents, [{}]), |
| 379 | + (coll.distinct, ["foo"]), |
| 380 | + (coll.aggregate, [[]]), |
| 381 | + (coll.find, [{}]), |
| 382 | + (coll.aggregate_raw_batches, [[]]), |
| 383 | + (coll.find_raw_batches, [{}]), |
| 384 | + (coll.database.command, ["find", coll.name]), |
| 385 | + ] |
| 386 | + for f, args in ops: |
| 387 | + with client.start_session() as s, s.start_transaction(): |
| 388 | + res = f(*args, session=s) |
| 389 | + if isinstance(res, (CommandCursor, Cursor)): |
| 390 | + list(res) |
| 391 | + |
354 | 392 |
|
355 | 393 | class PatchSessionTimeout(object):
|
356 | 394 | """Patches the client_session's with_transaction timeout for testing."""
|
|
0 commit comments