@@ -53,6 +53,7 @@ class _BaseTest(unittest.TestCase):
5353 SESSION_ID = "session_id"
5454 SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID
5555 TRANSACTION_ID = b"transaction_id"
56+ RETRY_TRANSACTION_ID = b"transaction_id_retry"
5657 BACKUP_ID = "backup_id"
5758 BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID
5859
@@ -735,8 +736,10 @@ def test_drop_success(self):
735736 )
736737
737738 def _execute_partitioned_dml_helper (
738- self , dml , params = None , param_types = None , query_options = None
739+ self , dml , params = None , param_types = None , query_options = None , retried = False
739740 ):
741+ from google .api_core .exceptions import Aborted
742+ from google .api_core .retry import Retry
740743 from google .protobuf .struct_pb2 import Struct
741744 from google .cloud .spanner_v1 .proto .result_set_pb2 import (
742745 PartialResultSet ,
@@ -752,6 +755,10 @@ def _execute_partitioned_dml_helper(
752755 _merge_query_options ,
753756 )
754757
758+ import collections
759+
760+ MethodConfig = collections .namedtuple ("MethodConfig" , ["retry" ])
761+
755762 transaction_pb = TransactionPB (id = self .TRANSACTION_ID )
756763
757764 stats_pb = ResultSetStats (row_count_lower_bound = 2 )
@@ -765,8 +772,14 @@ def _execute_partitioned_dml_helper(
765772 pool .put (session )
766773 database = self ._make_one (self .DATABASE_ID , instance , pool = pool )
767774 api = database ._spanner_api = self ._make_spanner_api ()
768- api .begin_transaction .return_value = transaction_pb
769- api .execute_streaming_sql .return_value = iterator
775+ api ._method_configs = {"ExecuteStreamingSql" : MethodConfig (retry = Retry ())}
776+ if retried :
777+ retry_transaction_pb = TransactionPB (id = self .RETRY_TRANSACTION_ID )
778+ api .begin_transaction .side_effect = [transaction_pb , retry_transaction_pb ]
779+ api .execute_streaming_sql .side_effect = [Aborted ("test" ), iterator ]
780+ else :
781+ api .begin_transaction .return_value = transaction_pb
782+ api .execute_streaming_sql .return_value = iterator
770783
771784 row_count = database .execute_partitioned_dml (
772785 dml , params , param_types , query_options
@@ -778,11 +791,15 @@ def _execute_partitioned_dml_helper(
778791 partitioned_dml = TransactionOptions .PartitionedDml ()
779792 )
780793
781- api .begin_transaction .assert_called_once_with (
794+ api .begin_transaction .assert_called_with (
782795 session .name ,
783796 txn_options ,
784797 metadata = [("google-cloud-resource-prefix" , database .name )],
785798 )
799+ if retried :
800+ self .assertEqual (api .begin_transaction .call_count , 2 )
801+ else :
802+ self .assertEqual (api .begin_transaction .call_count , 1 )
786803
787804 if params :
788805 expected_params = Struct (
@@ -798,7 +815,7 @@ def _execute_partitioned_dml_helper(
798815 expected_query_options , query_options
799816 )
800817
801- api .execute_streaming_sql .assert_called_once_with (
818+ api .execute_streaming_sql .assert_any_call (
802819 self .SESSION_NAME ,
803820 dml ,
804821 transaction = expected_transaction ,
@@ -807,6 +824,22 @@ def _execute_partitioned_dml_helper(
807824 query_options = expected_query_options ,
808825 metadata = [("google-cloud-resource-prefix" , database .name )],
809826 )
827+ if retried :
828+ expected_retry_transaction = TransactionSelector (
829+ id = self .RETRY_TRANSACTION_ID
830+ )
831+ api .execute_streaming_sql .assert_called_with (
832+ self .SESSION_NAME ,
833+ dml ,
834+ transaction = expected_retry_transaction ,
835+ params = expected_params ,
836+ param_types = param_types ,
837+ query_options = expected_query_options ,
838+ metadata = [("google-cloud-resource-prefix" , database .name )],
839+ )
840+ self .assertEqual (api .execute_streaming_sql .call_count , 2 )
841+ else :
842+ self .assertEqual (api .execute_streaming_sql .call_count , 1 )
810843
811844 def test_execute_partitioned_dml_wo_params (self ):
812845 self ._execute_partitioned_dml_helper (dml = DML_WO_PARAM )
@@ -828,6 +861,9 @@ def test_execute_partitioned_dml_w_query_options(self):
828861 query_options = ExecuteSqlRequest .QueryOptions (optimizer_version = "3" ),
829862 )
830863
864+ def test_execute_partitioned_dml_wo_params_retry_aborted (self ):
865+ self ._execute_partitioned_dml_helper (dml = DML_WO_PARAM , retried = True )
866+
831867 def test_session_factory_defaults (self ):
832868 from google .cloud .spanner_v1 .session import Session
833869
0 commit comments