11import random
2- import threading
32import time
43import unittest
54from threading import Thread
2120 switch_db ,
2221)
2322from mongoengine .pymongo_support import count_documents
24-
25- from .utils import (
23+ from tests .utils import (
2624 MongoDBTestCase ,
2725 requires_mongodb_gte_40 ,
2826 requires_mongodb_gte_44 ,
2927)
3028
3129
30+ class TestRollbackError (Exception ):
31+ pass
32+
33+
3234class TestableThread (Thread ):
3335 """
3436 Wrapper around `threading.Thread` that propagates exceptions.
@@ -525,7 +527,71 @@ class A(Document):
525527
526528 with run_in_transaction ():
527529 a_doc .update (name = "b" )
528- assert "b" == A .objects .get (id = a_doc .id ).name
530+ assert A .objects .get (id = a_doc .id ).name == "b"
531+ assert A .objects .count () == 1
532+
533+ assert A .objects .count () == 1
534+ assert A .objects .get (id = a_doc .id ).name == "b"
535+
536+ @requires_mongodb_gte_40
537+ def test_updating_a_document_within_a_transaction_that_fails (self ):
538+ connect ("mongoenginetest" )
539+
540+ class A (Document ):
541+ name = StringField ()
542+
543+ A .drop_collection ()
544+
545+ a_doc = A .objects .create (name = "a" )
546+
547+ with pytest .raises (TestRollbackError ):
548+ with run_in_transaction ():
549+ a_doc .update (name = "b" )
550+ assert A .objects .get (id = a_doc .id ).name == "b"
551+ raise TestRollbackError ()
552+
553+ assert A .objects .count () == 0
554+ assert A .objects .get (id = a_doc .id ).name == "a"
555+
556+ @requires_mongodb_gte_40
557+ def test_creating_a_document_within_a_transaction (self ):
558+ connect ("mongoenginetest" )
559+
560+ class A (Document ):
561+ name = StringField ()
562+
563+ A .drop_collection ()
564+
565+ with run_in_transaction ():
566+ a_doc = A .objects .create (name = "a" )
567+ another_doc = A (name = "b" ).save ()
568+ assert A .objects .get (id = a_doc .id ).name == "a"
569+ assert A .objects .get (id = another_doc .id ).name == "b"
570+ assert A .objects .count () == 2
571+
572+ assert A .objects .count () == 2
573+ assert A .objects .get (id = a_doc .id ).name == "a"
574+ assert A .objects .get (id = another_doc .id ).name == "b"
575+
576+ @requires_mongodb_gte_40
577+ def test_creating_a_document_within_a_transaction_that_fails (self ):
578+ connect ("mongoenginetest" )
579+
580+ class A (Document ):
581+ name = StringField ()
582+
583+ A .drop_collection ()
584+
585+ with pytest .raises (TestRollbackError ):
586+ with run_in_transaction ():
587+ a_doc = A .objects .create (name = "a" )
588+ another_doc = A (name = "b" ).save ()
589+ assert A .objects .get (id = a_doc .id ).name == "a"
590+ assert A .objects .get (id = another_doc .id ).name == "b"
591+ assert A .objects .count () == 2
592+ raise TestRollbackError ()
593+
594+ assert A .objects .count () == 0
529595
530596 @requires_mongodb_gte_40
531597 def test_transaction_updates_across_databases (self ):
@@ -640,17 +706,23 @@ class B(Document):
640706 B .drop_collection ()
641707 b_doc = B .objects .create (name = "b" )
642708
643- try :
709+ with pytest . raises ( TestRollbackError ) :
644710 with run_in_transaction ():
645711 a_doc .update (name = "trx-parent" )
646- with run_in_transaction ():
647- b_doc .update (name = "trx-child" )
648- raise Exception
649- except Exception :
650- pass
712+ try :
713+ with run_in_transaction ():
714+ b_doc .update (name = "trx-child" )
715+ raise TestRollbackError ()
716+ except TestRollbackError as exc :
717+ # at this stage, the parent transaction is still there
718+ assert A .objects .get (id = a_doc .id ).name == "trx-parent"
719+ raise exc
720+ else :
721+ # makes sure it enters the except above
722+ assert False
651723
652- assert "a" == A .objects .get (id = a_doc .id ).name
653- assert "b" == B .objects .get (id = b_doc .id ).name
724+ assert A .objects .get (id = a_doc .id ).name == "a"
725+ assert B .objects .get (id = b_doc .id ).name == "b"
654726
655727 @requires_mongodb_gte_40
656728 def test_exception_in_parent_of_nested_transaction_after_child_completed_only_rolls_parent_back (
@@ -670,26 +742,25 @@ class B(Document):
670742 B .drop_collection ()
671743 b_doc = B .objects .create (name = "b" )
672744
673- class TestExc (Exception ):
674- pass
675-
676745 def run_tx ():
677746 try :
678747 with run_in_transaction ():
679748 a_doc .update (name = "trx-parent" )
680749 with run_in_transaction ():
681750 b_doc .update (name = "trx-child" )
682- raise TestExc
683- except TestExc :
751+
752+ raise TestRollbackError ()
753+
754+ except TestRollbackError :
684755 pass
685- except OperationError as op_failure :
686- """
687- See thread safety test below for more details about TransientTransctionError handling
688- """
689- if "TransientTransactionError" in str (op_failure ):
690- run_tx ()
691- else :
692- raise op_failure
756+ # except OperationError as op_failure:
757+ # """
758+ # See thread safety test below for more details about TransientTransactionError handling
759+ # """
760+ # if "TransientTransactionError" in str(op_failure):
761+ # run_tx()
762+ # else:
763+ # raise op_failure
693764
694765 run_tx ()
695766 assert "a" == A .objects .get (id = a_doc .id ).name
@@ -702,11 +773,12 @@ def test_nested_transactions_create_and_release_sessions_accordingly(self):
702773 s1 = _get_session ()
703774 with run_in_transaction ():
704775 s2 = _get_session ()
705- assert s1 != s2
776+ assert s1 is not s2
706777 with run_in_transaction ():
707778 pass
708- assert s2 == _get_session ()
709- assert s1 == _get_session ()
779+ assert _get_session () is s2
780+ assert _get_session () is s1
781+
710782 assert _get_session () is None
711783
712784 @requires_mongodb_gte_40
@@ -740,24 +812,21 @@ class A(Document):
740812 # Ensure the collection is created
741813 A .objects .create (i = 0 )
742814
743- class TestExc (Exception ):
744- pass
745-
746815 def thread_fn (idx ):
747816 # Open the transaction at some unknown interval
748- time .sleep (random .uniform (0.01 , 0.1 ))
817+ time .sleep (random .uniform (0.1 , 0.5 ))
749818 try :
750819 with run_in_transaction ():
751820 a = A .objects .get (i = idx )
752821 a .i = idx * 10
753822 # Save at some unknown interval
754- time .sleep (random .uniform (0.01 , 0.1 ))
823+ time .sleep (random .uniform (0.1 , 0.5 ))
755824 a .save ()
756825
757826 # Force roll backs for the even runs...
758827 if idx % 2 == 0 :
759- raise TestExc
760- except TestExc :
828+ raise TestRollbackError ()
829+ except TestRollbackError :
761830 pass
762831 except pymongo .errors .OperationFailure as op_failure :
763832 """
@@ -775,7 +844,7 @@ def thread_fn(idx):
775844 else :
776845 raise op_failure
777846
778- for r in range (10 ):
847+ for r in range (5 ):
779848 """
780849 Threads & randomization are tricky - run it multiple times
781850 """
@@ -784,14 +853,13 @@ def thread_fn(idx):
784853 A .objects .all ().delete ()
785854
786855 # Prepopulate the data for reads
787- thread_count = 10
856+ thread_count = 20
788857 for i in range (thread_count ):
789858 A .objects .create (i = i )
790859
791860 # Prime the threads
792861 threads = [
793- threading .Thread (target = thread_fn , args = (i ,))
794- for i in range (thread_count )
862+ TestableThread (target = thread_fn , args = (i ,)) for i in range (thread_count )
795863 ]
796864 for t in threads :
797865 t .start ()
@@ -805,7 +873,7 @@ def thread_fn(idx):
805873 expected_sum += i
806874 else :
807875 expected_sum += i * 10
808- assert expected_sum == 270
876+ assert expected_sum == 1090
809877 assert expected_sum == sum (a .i for a in A .objects .all ())
810878
811879
0 commit comments