Skip to content

Commit de31667

Browse files
committed
Add RPC Retries via interceptor
Add RPC retries via an interceptor, with exponential backoff matching both the Go and Java clients. However, the approach here differs from both in two main ways: 1. Adding these via an interceptor makes retries implicit, while both the Go and Java clients require them to be explicit. 2. The specific requests to retry are based on GRPC error codes, rather than explicitly listing non-retryable errors and retrying everything by default. This seems like a more sustainable approach, since nearly every error type is non-retryable. A newly introduced error type would otherwise require a client update to mark it non-retryable before it could safely be used. Any time the Python client doesn't recognize an error it gets mapped to just CadenceError, so new errors can safely be added.
1 parent c31793f commit de31667

File tree

4 files changed

+264
-18
lines changed

4 files changed

+264
-18
lines changed

cadence/_internal/rpc/retry.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import asyncio
2+
from dataclasses import dataclass
3+
from typing import Callable, Any
4+
5+
from grpc import StatusCode
6+
from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails
7+
8+
from cadence.error import CadenceError, EntityNotExistsError
9+
10+
# GRPC status codes that are safe to retry regardless of the RPC. Everything
# else is non-retryable by default, so newly introduced server errors (which
# map to plain CadenceError) are not retried until explicitly added here.
RETRYABLE_CODES = {
    StatusCode.INTERNAL,
    StatusCode.RESOURCE_EXHAUSTED,
    StatusCode.ABORTED,
    StatusCode.UNAVAILABLE
}
16+
17+
# No expiration interval, use the GRPC timeout value instead
18+
@dataclass
19+
class ExponentialRetryPolicy:
20+
initial_interval: float
21+
backoff_coefficient: float
22+
max_interval: float
23+
max_attempts: float
24+
25+
def next_delay(self, attempts: int, elapsed: float, expiration: float) -> float | None:
26+
if elapsed >= expiration:
27+
return None
28+
if self.max_attempts != 0 and attempts >= self.max_attempts:
29+
return None
30+
31+
backoff = min(self.initial_interval * pow(self.backoff_coefficient, attempts-1), self.max_interval)
32+
if (elapsed + backoff) >= expiration:
33+
return None
34+
35+
return backoff
36+
37+
# Defaults: 20ms initial delay growing 1.2x per attempt, capped at 6s.
# max_attempts=0 means no attempt limit (see next_delay); the GRPC timeout
# bounds retries instead.
DEFAULT_RETRY_POLICY = ExponentialRetryPolicy(initial_interval=0.02, backoff_coefficient=1.2, max_interval=6, max_attempts=0)
38+
39+
class RetryInterceptor(UnaryUnaryClientInterceptor):
    """Transparently retries unary-unary RPCs that fail with retryable errors.

    The backoff schedule comes from an ExponentialRetryPolicy; the overall
    budget is the GRPC timeout on the call, when one is set. Whether an error
    is retryable is decided by is_retryable.
    """

    def __init__(self, retry_policy: ExponentialRetryPolicy = DEFAULT_RETRY_POLICY):
        super().__init__()
        self._retry_policy = retry_policy

    async def intercept_unary_unary(
            self,
            continuation: Callable[[ClientCallDetails, Any], Any],
            client_call_details: ClientCallDetails,
            request: Any
    ) -> Any:
        loop = asyncio.get_running_loop()
        # GRPC timeouts are optional. A call without one has no deadline, so
        # treat the expiration as infinite and let the policy's max_attempts
        # bound the retries. (Previously a None timeout raised TypeError below.)
        expiration_interval = client_call_details.timeout
        if expiration_interval is None:
            expiration_interval = float("inf")
        start_time = loop.time()
        deadline = start_time + expiration_interval
        rpc_call = await continuation(client_call_details, request)
        try:
            # Return the result directly if success. GRPC will wrap it back into a UnaryUnaryCall
            return await rpc_call
        except CadenceError as e:
            err = e
            attempts = 1
            while is_retryable(err, client_call_details):
                elapsed = loop.time() - start_time
                backoff = self._retry_policy.next_delay(attempts, elapsed, expiration_interval)
                if backoff is None:
                    break

                await asyncio.sleep(backoff)
                # Shrink the per-attempt timeout so the overall deadline holds
                remaining = deadline - loop.time()
                # Namedtuple methods start with an underscore to avoid conflicts and aren't actually private
                # noinspection PyProtectedMember
                call_details = client_call_details._replace(
                    timeout=remaining if remaining != float("inf") else None)
                # Start a retry
                attempts += 1
                rpc_call = await continuation(call_details, request)
                try:
                    # Return the result if it's a success
                    return await rpc_call
                except CadenceError as e:
                    err = e

            # On policy expiration, return the most recent UnaryUnaryCall. It has the error we want
            return rpc_call
83+
84+
85+
86+
def is_retryable(err: CadenceError, call_details: ClientCallDetails) -> bool:
    """Decide whether an RPC that failed with *err* should be retried."""
    # Special case requests that may have landed on the passive cluster,
    # matching the Go and Java clients: a not-found from workflow history is
    # only worth retrying when the clusters disagree.
    history_method = b'/uber.cadence.api.v1.WorkflowAPI/GetWorkflowExecutionHistory'
    if call_details.method == history_method and isinstance(err, EntityNotExistsError):
        active = err.active_cluster
        current = err.current_cluster
        if active is None or current is None:
            return False
        return active != current

    return err.code in RETRYABLE_CODES

cadence/_internal/rpc/yarpc.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,9 @@
1-
import collections
21
from typing import Any, Callable
32

43
from grpc.aio import Metadata
54
from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails
65

76

8-
class _ClientCallDetails(
9-
collections.namedtuple(
10-
"_ClientCallDetails", ("method", "timeout", "metadata", "credentials", "wait_for_ready")
11-
),
12-
ClientCallDetails,
13-
):
14-
pass
15-
167
SERVICE_KEY = "rpc-service"
178
CALLER_KEY = "rpc-caller"
189
ENCODING_KEY = "rpc-encoding"
@@ -42,11 +33,6 @@ def _replace_details(self, client_call_details: ClientCallDetails) -> ClientCall
4233
else:
4334
metadata += self._metadata
4435

45-
return _ClientCallDetails(
46-
method=client_call_details.method,
47-
# YARPC seems to require a TTL value
48-
timeout=client_call_details.timeout or 60.0,
49-
metadata=metadata,
50-
credentials=client_call_details.credentials,
51-
wait_for_ready=client_call_details.wait_for_ready,
52-
)
36+
# Namedtuple methods start with an underscore to avoid conflicts and aren't actually private
37+
# noinspection PyProtectedMember
38+
return client_call_details._replace(metadata=metadata, timeout=client_call_details.timeout or 60.0)

cadence/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from grpc import ChannelCredentials, Compression
66

77
from cadence._internal.rpc.error import CadenceErrorInterceptor
8+
from cadence._internal.rpc.retry import RetryInterceptor
89
from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor
910
from cadence.api.v1.service_domain_pb2_grpc import DomainAPIStub
1011
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
@@ -91,8 +92,9 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
9192

9293
def _create_channel(options: ClientOptions) -> Channel:
9394
interceptors = list(options["interceptors"])
94-
interceptors.append(CadenceErrorInterceptor())
9595
interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"]))
96+
interceptors.append(RetryInterceptor())
97+
interceptors.append(CadenceErrorInterceptor())
9698

9799
if options["credentials"]:
98100
return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors)
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from concurrent import futures
2+
from typing import Tuple, Type
3+
4+
import pytest
5+
from google.protobuf import any_pb2
6+
from google.rpc import status_pb2, code_pb2
7+
from grpc import server
8+
from grpc.aio import insecure_channel
9+
from grpc_status.rpc_status import to_status
10+
11+
from cadence._internal.rpc.error import CadenceErrorInterceptor
12+
from cadence.api.v1 import error_pb2, service_workflow_pb2_grpc
13+
14+
from cadence._internal.rpc.retry import ExponentialRetryPolicy, RetryInterceptor
15+
from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionResponse, \
16+
DescribeWorkflowExecutionRequest, GetWorkflowExecutionHistoryRequest
17+
from cadence.error import CadenceError, FeatureNotEnabledError, EntityNotExistsError
18+
19+
simple_policy = ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=2, max_interval=10, max_attempts=6)


# Table-driven check of the backoff schedule: (attempts, elapsed, expiration) -> delay.
@pytest.mark.parametrize(
    "policy,params,expected",
    [
        pytest.param(simple_policy, (1, 0.0, 100.0), 1, id="happy path"),
        pytest.param(simple_policy, (2, 0.0, 100.0), 2, id="second attempt"),
        pytest.param(simple_policy, (3, 0.0, 100.0), 4, id="third attempt"),
        pytest.param(simple_policy, (5, 0.0, 100.0), 10, id="capped by max_interval"),
        pytest.param(simple_policy, (6, 0.0, 100.0), None, id="out of attempts"),
        pytest.param(simple_policy, (1, 100.0, 100.0), None, id="timeout"),
        pytest.param(simple_policy, (1, 99.0, 100.0), None, id="backoff causes timeout"),
        pytest.param(
            ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=1, max_interval=10, max_attempts=0),
            (100, 0.0, 100.0),
            1,
            id="unlimited retries",
        ),
    ],
)
def test_next_delay(policy: ExponentialRetryPolicy, params: Tuple[int, float, float], expected: float | None):
    attempts, elapsed, expiration = params
    assert policy.next_delay(attempts, elapsed, expiration) == expected
52+
53+
54+
class FakeService(service_workflow_pb2_grpc.WorkflowAPIServicer):
    """In-process WorkflowAPI stub that scripts error responses for retry tests."""

    def __init__(self) -> None:
        super().__init__()
        # Filled in by the fixture once the server binds an ephemeral port.
        self.port = None
        # Counts every RPC received; tests reset it and assert on call counts.
        self.counter = 0

    # Retryable only because it's GetWorkflowExecutionHistory
    def GetWorkflowExecutionHistory(self, request: GetWorkflowExecutionHistoryRequest, context):
        self.counter += 1

        detail = any_pb2.Any()
        # current_cluster echoes the request domain so the test can choose
        # whether it matches active_cluster ("active"), i.e. whether the
        # resulting EntityNotExistsError is retryable.
        detail.Pack(error_pb2.EntityNotExistsError(current_cluster=request.domain, active_cluster="active"))
        status_proto = status_pb2.Status(
            code=code_pb2.NOT_FOUND,
            message="message",
            details=[detail],
        )
        context.abort_with_status(to_status(status_proto))
        # Unreachable

    # Not retryable
    def DescribeWorkflowExecution(self, request: DescribeWorkflowExecutionRequest, context):
        self.counter += 1

        # The request domain selects the scripted behavior.
        if request.domain == "success":
            return DescribeWorkflowExecutionResponse()
        elif request.domain == "retryable":
            code = code_pb2.RESOURCE_EXHAUSTED
        elif request.domain == "maybe later":
            # Fail the first two calls, then succeed, to exercise retry-then-success.
            if self.counter >= 3:
                return DescribeWorkflowExecutionResponse()

            code = code_pb2.RESOURCE_EXHAUSTED
        else:
            code = code_pb2.PERMISSION_DENIED

        detail = any_pb2.Any()
        detail.Pack(error_pb2.FeatureNotEnabledError(feature_flag="the flag"))
        status_proto = status_pb2.Status(
            code=code,
            message="message",
            details=[detail],
        )
        context.abort_with_status(to_status(status_proto))
        # Unreachable
100+
101+
102+
@pytest.fixture(scope="module")
def fake_service():
    """Start one real (sync) GRPC server shared by every test in this module."""
    fake = FakeService()
    sync_server = server(futures.ThreadPoolExecutor(max_workers=1))
    service_workflow_pb2_grpc.add_WorkflowAPIServicer_to_server(fake, sync_server)
    # Port 0 lets the OS pick a free port; remember it so clients can connect.
    fake.port = sync_server.add_insecure_port("[::]:0")
    sync_server.start()
    yield fake
    # grace=None stops immediately; no in-flight RPCs matter after the tests.
    sync_server.stop(grace=None)
111+
112+
# Zero initial_interval/backoff_coefficient means retries happen without
# sleeping, keeping the tests fast; max_attempts=10 bounds them.
TEST_POLICY = ExponentialRetryPolicy(initial_interval=0, backoff_coefficient=0, max_interval=10, max_attempts=10)

@pytest.mark.usefixtures("fake_service")
@pytest.mark.parametrize(
    "case,expected_calls,expected_err",
    [
        pytest.param(
            "success", 1, None, id="happy path"
        ),
        pytest.param(
            "maybe later", 3, None, id="retries then success"
        ),
        pytest.param(
            "not retryable", 1, FeatureNotEnabledError, id="not retryable"
        ),
        pytest.param(
            "retryable", TEST_POLICY.max_attempts, FeatureNotEnabledError, id="retries exhausted"
        ),

    ]
)
@pytest.mark.asyncio
async def test_retryable_error(fake_service, case: str, expected_calls: int, expected_err: Type[CadenceError]):
    """End-to-end retry behavior through a real channel; *case* selects the
    scripted FakeService response and expected_calls pins the retry count."""
    # The fixture is module-scoped, so reset the shared call counter per test.
    fake_service.counter = 0
    async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel:
        stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel)
        if expected_err:
            with pytest.raises(expected_err):
                await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10)
        else:
            await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10)

    assert fake_service.counter == expected_calls
145+
146+
@pytest.mark.usefixtures("fake_service")
@pytest.mark.parametrize(
    "case,expected_calls",
    [
        pytest.param(
            "active", 1, id="not retryable"
        ),
        pytest.param(
            "not active", TEST_POLICY.max_attempts, id="retries exhausted"
        ),

    ]
)
@pytest.mark.asyncio
async def test_workflow_history(fake_service, case: str, expected_calls: int):
    """GetWorkflowExecutionHistory always fails with EntityNotExistsError; it
    is only retried when the error's current_cluster (echoed from *case* by
    FakeService) differs from its active_cluster ("active")."""
    # The fixture is module-scoped, so reset the shared call counter per test.
    fake_service.counter = 0
    async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel:
        stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel)
        with pytest.raises(EntityNotExistsError):
            await stub.GetWorkflowExecutionHistory(GetWorkflowExecutionHistoryRequest(domain=case), timeout=10)

    assert fake_service.counter == expected_calls

0 commit comments

Comments
 (0)