Skip to content

Commit 8a3d700

Browse files
authored
feat: add support for retrying aborted partitioned DML statements (#66)
* feat: add support for retrying aborted partitioned DML statements
* run blacken
* use retry settings from config
* fix imports from rebase

Co-authored-by: larkee <[email protected]>
1 parent df4be7f commit 8a3d700

File tree

2 files changed

+83
-22
lines changed

2 files changed

+83
-22
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 42 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,10 @@
2121
import threading
2222

2323
import google.auth.credentials
24+
from google.api_core.retry import if_exception_type
2425
from google.protobuf.struct_pb2 import Struct
2526
from google.cloud.exceptions import NotFound
27+
from google.api_core.exceptions import Aborted
2628
import six
2729

2830
# pylint: disable=ungrouped-imports
@@ -394,29 +396,36 @@ def execute_partitioned_dml(
394396

395397
metadata = _metadata_with_prefix(self.name)
396398

397-
with SessionCheckout(self._pool) as session:
399+
def execute_pdml():
400+
with SessionCheckout(self._pool) as session:
401+
402+
txn = api.begin_transaction(
403+
session.name, txn_options, metadata=metadata
404+
)
398405

399-
txn = api.begin_transaction(session.name, txn_options, metadata=metadata)
406+
txn_selector = TransactionSelector(id=txn.id)
407+
408+
restart = functools.partial(
409+
api.execute_streaming_sql,
410+
session.name,
411+
dml,
412+
transaction=txn_selector,
413+
params=params_pb,
414+
param_types=param_types,
415+
query_options=query_options,
416+
metadata=metadata,
417+
)
400418

401-
txn_selector = TransactionSelector(id=txn.id)
419+
iterator = _restart_on_unavailable(restart)
402420

403-
restart = functools.partial(
404-
api.execute_streaming_sql,
405-
session.name,
406-
dml,
407-
transaction=txn_selector,
408-
params=params_pb,
409-
param_types=param_types,
410-
query_options=query_options,
411-
metadata=metadata,
412-
)
421+
result_set = StreamedResultSet(iterator)
422+
list(result_set) # consume all partials
413423

414-
iterator = _restart_on_unavailable(restart)
424+
return result_set.stats.row_count_lower_bound
415425

416-
result_set = StreamedResultSet(iterator)
417-
list(result_set) # consume all partials
426+
retry_config = api._method_configs["ExecuteStreamingSql"].retry
418427

419-
return result_set.stats.row_count_lower_bound
428+
return _retry_on_aborted(execute_pdml, retry_config)()
420429

421430
def session(self, labels=None):
422431
"""Factory to create a session for this database.
@@ -976,3 +985,19 @@ def __init__(self, source_type, backup_info):
976985
@classmethod
977986
def from_pb(cls, pb):
978987
return cls(pb.source_type, pb.backup_info)
988+
989+
990+
def _retry_on_aborted(func, retry_config):
    """Helper for :meth:`Database.execute_partitioned_dml`.

    Derive a retry object from ``retry_config`` whose predicate matches
    ``Aborted`` exceptions, and use it to wrap ``func``.

    :type func: callable
    :param func: the function to be retried on Aborted exceptions

    :type retry_config: Retry
    :param retry_config: retry object with the settings to be used

    :rtype: callable
    :returns: ``func`` wrapped by the derived retry object
    """
    # Only the predicate is swapped; the remaining settings come from the
    # retry object supplied by the caller.
    aborted_only = if_exception_type(Aborted)
    return retry_config.with_predicate(aborted_only)(func)

tests/unit/test_database.py

Lines changed: 41 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ class _BaseTest(unittest.TestCase):
5353
SESSION_ID = "session_id"
5454
SESSION_NAME = DATABASE_NAME + "/sessions/" + SESSION_ID
5555
TRANSACTION_ID = b"transaction_id"
56+
RETRY_TRANSACTION_ID = b"transaction_id_retry"
5657
BACKUP_ID = "backup_id"
5758
BACKUP_NAME = INSTANCE_NAME + "/backups/" + BACKUP_ID
5859

@@ -735,8 +736,10 @@ def test_drop_success(self):
735736
)
736737

737738
def _execute_partitioned_dml_helper(
738-
self, dml, params=None, param_types=None, query_options=None
739+
self, dml, params=None, param_types=None, query_options=None, retried=False
739740
):
741+
from google.api_core.exceptions import Aborted
742+
from google.api_core.retry import Retry
740743
from google.protobuf.struct_pb2 import Struct
741744
from google.cloud.spanner_v1.proto.result_set_pb2 import (
742745
PartialResultSet,
@@ -752,6 +755,10 @@ def _execute_partitioned_dml_helper(
752755
_merge_query_options,
753756
)
754757

758+
import collections
759+
760+
MethodConfig = collections.namedtuple("MethodConfig", ["retry"])
761+
755762
transaction_pb = TransactionPB(id=self.TRANSACTION_ID)
756763

757764
stats_pb = ResultSetStats(row_count_lower_bound=2)
@@ -765,8 +772,14 @@ def _execute_partitioned_dml_helper(
765772
pool.put(session)
766773
database = self._make_one(self.DATABASE_ID, instance, pool=pool)
767774
api = database._spanner_api = self._make_spanner_api()
768-
api.begin_transaction.return_value = transaction_pb
769-
api.execute_streaming_sql.return_value = iterator
775+
api._method_configs = {"ExecuteStreamingSql": MethodConfig(retry=Retry())}
776+
if retried:
777+
retry_transaction_pb = TransactionPB(id=self.RETRY_TRANSACTION_ID)
778+
api.begin_transaction.side_effect = [transaction_pb, retry_transaction_pb]
779+
api.execute_streaming_sql.side_effect = [Aborted("test"), iterator]
780+
else:
781+
api.begin_transaction.return_value = transaction_pb
782+
api.execute_streaming_sql.return_value = iterator
770783

771784
row_count = database.execute_partitioned_dml(
772785
dml, params, param_types, query_options
@@ -778,11 +791,15 @@ def _execute_partitioned_dml_helper(
778791
partitioned_dml=TransactionOptions.PartitionedDml()
779792
)
780793

781-
api.begin_transaction.assert_called_once_with(
794+
api.begin_transaction.assert_called_with(
782795
session.name,
783796
txn_options,
784797
metadata=[("google-cloud-resource-prefix", database.name)],
785798
)
799+
if retried:
800+
self.assertEqual(api.begin_transaction.call_count, 2)
801+
else:
802+
self.assertEqual(api.begin_transaction.call_count, 1)
786803

787804
if params:
788805
expected_params = Struct(
@@ -798,7 +815,7 @@ def _execute_partitioned_dml_helper(
798815
expected_query_options, query_options
799816
)
800817

801-
api.execute_streaming_sql.assert_called_once_with(
818+
api.execute_streaming_sql.assert_any_call(
802819
self.SESSION_NAME,
803820
dml,
804821
transaction=expected_transaction,
@@ -807,6 +824,22 @@ def _execute_partitioned_dml_helper(
807824
query_options=expected_query_options,
808825
metadata=[("google-cloud-resource-prefix", database.name)],
809826
)
827+
if retried:
828+
expected_retry_transaction = TransactionSelector(
829+
id=self.RETRY_TRANSACTION_ID
830+
)
831+
api.execute_streaming_sql.assert_called_with(
832+
self.SESSION_NAME,
833+
dml,
834+
transaction=expected_retry_transaction,
835+
params=expected_params,
836+
param_types=param_types,
837+
query_options=expected_query_options,
838+
metadata=[("google-cloud-resource-prefix", database.name)],
839+
)
840+
self.assertEqual(api.execute_streaming_sql.call_count, 2)
841+
else:
842+
self.assertEqual(api.execute_streaming_sql.call_count, 1)
810843

811844
def test_execute_partitioned_dml_wo_params(self):
812845
self._execute_partitioned_dml_helper(dml=DML_WO_PARAM)
@@ -828,6 +861,9 @@ def test_execute_partitioned_dml_w_query_options(self):
828861
query_options=ExecuteSqlRequest.QueryOptions(optimizer_version="3"),
829862
)
830863

864+
def test_execute_partitioned_dml_wo_params_retry_aborted(self):
865+
self._execute_partitioned_dml_helper(dml=DML_WO_PARAM, retried=True)
866+
831867
def test_session_factory_defaults(self):
832868
from google.cloud.spanner_v1.session import Session
833869

0 commit comments

Comments
 (0)