Skip to content

Commit fe9e42a

Browse files
committed
fix minor things and add increase tests cov for tx support
1 parent c33cd1d commit fe9e42a

File tree

5 files changed

+114
-42
lines changed

5 files changed

+114
-42
lines changed

docs/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Changelog
77
Development
88
===========
99
- (Fill this out as you fix issues and develop your features).
10+
- Add support for transaction through run_in_transaction #2569
1011

1112
Changes in 0.29.0
1213
=================

mongoengine/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def append(self, session):
484484

485485
def get_current(self):
486486
if len(self.sessions):
487-
return self.sessions[len(self.sessions) - 1]
487+
return self.sessions[-1]
488488

489489
def clear_current(self):
490490
if len(self.sessions):

mongoengine/pymongo_support.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def count_documents(
4040
# count_documents appeared in pymongo 3.7
4141
if PYMONGO_VERSION >= (3, 7):
4242
try:
43-
if not filter and set(kwargs) <= {"max_time_ms"}:
43+
is_active_session = connection._get_session() is not None
44+
if not filter and set(kwargs) <= {"max_time_ms"} and not is_active_session:
4445
# when no filter is provided, estimated_document_count
4546
# is a lot faster as it uses the collection metadata
4647
return collection.estimated_document_count(**kwargs)

tests/test_connection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,8 @@ def test_multiple_connection_settings(self):
647647
# Purposely not catching exception to fail test if thrown.
648648
mongo_connections["t1"].server_info()
649649
mongo_connections["t2"].server_info()
650+
assert mongo_connections["t1"].address[0] == "localhost"
651+
assert mongo_connections["t2"].address[0] == "127.0.0.1"
650652
assert mongo_connections["t1"] is not mongo_connections["t2"]
651653

652654
def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self):

tests/test_context_managers.py

Lines changed: 108 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import random
2-
import threading
32
import time
43
import unittest
54
from threading import Thread
@@ -21,14 +20,17 @@
2120
switch_db,
2221
)
2322
from mongoengine.pymongo_support import count_documents
24-
25-
from .utils import (
23+
from tests.utils import (
2624
MongoDBTestCase,
2725
requires_mongodb_gte_40,
2826
requires_mongodb_gte_44,
2927
)
3028

3129

30+
class TestRollbackError(Exception):
31+
pass
32+
33+
3234
class TestableThread(Thread):
3335
"""
3436
Wrapper around `threading.Thread` that propagates exceptions.
@@ -525,7 +527,71 @@ class A(Document):
525527

526528
with run_in_transaction():
527529
a_doc.update(name="b")
528-
assert "b" == A.objects.get(id=a_doc.id).name
530+
assert A.objects.get(id=a_doc.id).name == "b"
531+
assert A.objects.count() == 1
532+
533+
assert A.objects.count() == 1
534+
assert A.objects.get(id=a_doc.id).name == "b"
535+
536+
@requires_mongodb_gte_40
537+
def test_updating_a_document_within_a_transaction_that_fails(self):
538+
connect("mongoenginetest")
539+
540+
class A(Document):
541+
name = StringField()
542+
543+
A.drop_collection()
544+
545+
a_doc = A.objects.create(name="a")
546+
547+
with pytest.raises(TestRollbackError):
548+
with run_in_transaction():
549+
a_doc.update(name="b")
550+
assert A.objects.get(id=a_doc.id).name == "b"
551+
raise TestRollbackError()
552+
553+
assert A.objects.count() == 0
554+
assert A.objects.get(id=a_doc.id).name == "a"
555+
556+
@requires_mongodb_gte_40
557+
def test_creating_a_document_within_a_transaction(self):
558+
connect("mongoenginetest")
559+
560+
class A(Document):
561+
name = StringField()
562+
563+
A.drop_collection()
564+
565+
with run_in_transaction():
566+
a_doc = A.objects.create(name="a")
567+
another_doc = A(name="b").save()
568+
assert A.objects.get(id=a_doc.id).name == "a"
569+
assert A.objects.get(id=another_doc.id).name == "b"
570+
assert A.objects.count() == 2
571+
572+
assert A.objects.count() == 2
573+
assert A.objects.get(id=a_doc.id).name == "a"
574+
assert A.objects.get(id=another_doc.id).name == "b"
575+
576+
@requires_mongodb_gte_40
577+
def test_creating_a_document_within_a_transaction_that_fails(self):
578+
connect("mongoenginetest")
579+
580+
class A(Document):
581+
name = StringField()
582+
583+
A.drop_collection()
584+
585+
with pytest.raises(TestRollbackError):
586+
with run_in_transaction():
587+
a_doc = A.objects.create(name="a")
588+
another_doc = A(name="b").save()
589+
assert A.objects.get(id=a_doc.id).name == "a"
590+
assert A.objects.get(id=another_doc.id).name == "b"
591+
assert A.objects.count() == 2
592+
raise TestRollbackError()
593+
594+
assert A.objects.count() == 0
529595

530596
@requires_mongodb_gte_40
531597
def test_transaction_updates_across_databases(self):
@@ -640,17 +706,23 @@ class B(Document):
640706
B.drop_collection()
641707
b_doc = B.objects.create(name="b")
642708

643-
try:
709+
with pytest.raises(TestRollbackError):
644710
with run_in_transaction():
645711
a_doc.update(name="trx-parent")
646-
with run_in_transaction():
647-
b_doc.update(name="trx-child")
648-
raise Exception
649-
except Exception:
650-
pass
712+
try:
713+
with run_in_transaction():
714+
b_doc.update(name="trx-child")
715+
raise TestRollbackError()
716+
except TestRollbackError as exc:
717+
# at this stage, the parent transaction is still there
718+
assert A.objects.get(id=a_doc.id).name == "trx-parent"
719+
raise exc
720+
else:
721+
# makes sure it enters the except above
722+
assert False
651723

652-
assert "a" == A.objects.get(id=a_doc.id).name
653-
assert "b" == B.objects.get(id=b_doc.id).name
724+
assert A.objects.get(id=a_doc.id).name == "a"
725+
assert B.objects.get(id=b_doc.id).name == "b"
654726

655727
@requires_mongodb_gte_40
656728
def test_exception_in_parent_of_nested_transaction_after_child_completed_only_rolls_parent_back(
@@ -670,26 +742,25 @@ class B(Document):
670742
B.drop_collection()
671743
b_doc = B.objects.create(name="b")
672744

673-
class TestExc(Exception):
674-
pass
675-
676745
def run_tx():
677746
try:
678747
with run_in_transaction():
679748
a_doc.update(name="trx-parent")
680749
with run_in_transaction():
681750
b_doc.update(name="trx-child")
682-
raise TestExc
683-
except TestExc:
751+
752+
raise TestRollbackError()
753+
754+
except TestRollbackError:
684755
pass
685-
except OperationError as op_failure:
686-
"""
687-
See thread safety test below for more details about TransientTransctionError handling
688-
"""
689-
if "TransientTransactionError" in str(op_failure):
690-
run_tx()
691-
else:
692-
raise op_failure
756+
# except OperationError as op_failure:
757+
# """
758+
# See thread safety test below for more details about TransientTransactionError handling
759+
# """
760+
# if "TransientTransactionError" in str(op_failure):
761+
# run_tx()
762+
# else:
763+
# raise op_failure
693764

694765
run_tx()
695766
assert "a" == A.objects.get(id=a_doc.id).name
@@ -702,11 +773,12 @@ def test_nested_transactions_create_and_release_sessions_accordingly(self):
702773
s1 = _get_session()
703774
with run_in_transaction():
704775
s2 = _get_session()
705-
assert s1 != s2
776+
assert s1 is not s2
706777
with run_in_transaction():
707778
pass
708-
assert s2 == _get_session()
709-
assert s1 == _get_session()
779+
assert _get_session() is s2
780+
assert _get_session() is s1
781+
710782
assert _get_session() is None
711783

712784
@requires_mongodb_gte_40
@@ -740,24 +812,21 @@ class A(Document):
740812
# Ensure the collection is created
741813
A.objects.create(i=0)
742814

743-
class TestExc(Exception):
744-
pass
745-
746815
def thread_fn(idx):
747816
# Open the transaction at some unknown interval
748-
time.sleep(random.uniform(0.01, 0.1))
817+
time.sleep(random.uniform(0.1, 0.5))
749818
try:
750819
with run_in_transaction():
751820
a = A.objects.get(i=idx)
752821
a.i = idx * 10
753822
# Save at some unknown interval
754-
time.sleep(random.uniform(0.01, 0.1))
823+
time.sleep(random.uniform(0.1, 0.5))
755824
a.save()
756825

757826
# Force roll backs for the even runs...
758827
if idx % 2 == 0:
759-
raise TestExc
760-
except TestExc:
828+
raise TestRollbackError()
829+
except TestRollbackError:
761830
pass
762831
except pymongo.errors.OperationFailure as op_failure:
763832
"""
@@ -775,7 +844,7 @@ def thread_fn(idx):
775844
else:
776845
raise op_failure
777846

778-
for r in range(10):
847+
for r in range(5):
779848
"""
780849
Threads & randomization are tricky - run it multiple times
781850
"""
@@ -784,14 +853,13 @@ def thread_fn(idx):
784853
A.objects.all().delete()
785854

786855
# Prepopulate the data for reads
787-
thread_count = 10
856+
thread_count = 20
788857
for i in range(thread_count):
789858
A.objects.create(i=i)
790859

791860
# Prime the threads
792861
threads = [
793-
threading.Thread(target=thread_fn, args=(i,))
794-
for i in range(thread_count)
862+
TestableThread(target=thread_fn, args=(i,)) for i in range(thread_count)
795863
]
796864
for t in threads:
797865
t.start()
@@ -805,7 +873,7 @@ def thread_fn(idx):
805873
expected_sum += i
806874
else:
807875
expected_sum += i * 10
808-
assert expected_sum == 270
876+
assert expected_sum == 1090
809877
assert expected_sum == sum(a.i for a in A.objects.all())
810878

811879

0 commit comments

Comments
 (0)