Skip to content

Commit a5a257c

Browse files
authored
Add RPC Retries via interceptor (#24)
1 parent 959f812 commit a5a257c

File tree

4 files changed

+260
-18
lines changed

4 files changed

+260
-18
lines changed

cadence/_internal/rpc/retry.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import asyncio
2+
from dataclasses import dataclass
3+
from typing import Callable, Any
4+
5+
from grpc import StatusCode
6+
from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails
7+
8+
from cadence.error import CadenceError, EntityNotExistsError
9+
10+
# Status codes treated as transient failures: any RPC that fails with one of
# these is eligible for retry (subject to the retry policy's limits).
RETRYABLE_CODES = {
    StatusCode.INTERNAL,
    StatusCode.RESOURCE_EXHAUSTED,
    StatusCode.ABORTED,
    StatusCode.UNAVAILABLE
}
16+
17+
# No expiration interval, use the GRPC timeout value instead
18+
@dataclass
19+
class ExponentialRetryPolicy:
20+
initial_interval: float
21+
backoff_coefficient: float
22+
max_interval: float
23+
max_attempts: float
24+
25+
def next_delay(self, attempts: int, elapsed: float, expiration: float) -> float | None:
26+
if elapsed >= expiration:
27+
return None
28+
if self.max_attempts != 0 and attempts >= self.max_attempts:
29+
return None
30+
31+
backoff = min(self.initial_interval * pow(self.backoff_coefficient, attempts-1), self.max_interval)
32+
if (elapsed + backoff) >= expiration:
33+
return None
34+
35+
return backoff
36+
37+
# Default: 20ms initial delay growing gently (x1.2), capped at 6s per sleep,
# with no attempt limit — the per-call gRPC timeout is the only overall bound.
DEFAULT_RETRY_POLICY = ExponentialRetryPolicy(initial_interval=0.02, backoff_coefficient=1.2, max_interval=6, max_attempts=0)
# Full gRPC method path, used to special-case GetWorkflowExecutionHistory in is_retryable.
GET_WORKFLOW_HISTORY = b'/uber.cadence.api.v1.WorkflowAPI/GetWorkflowExecutionHistory'
39+
40+
class RetryInterceptor(UnaryUnaryClientInterceptor):
    """grpc.aio interceptor that retries unary-unary calls per a retry policy.

    Each attempt is issued with a shrinking per-attempt timeout so the whole
    retried sequence never exceeds the caller's original deadline.
    """

    def __init__(self, retry_policy: ExponentialRetryPolicy = DEFAULT_RETRY_POLICY):
        super().__init__()
        # Policy deciding both how long to back off and when to give up.
        self._retry_policy = retry_policy

    async def intercept_unary_unary(
        self,
        continuation: Callable[[ClientCallDetails, Any], Any],
        client_call_details: ClientCallDetails,
        request: Any
    ) -> Any:
        loop = asyncio.get_running_loop()
        # NOTE(review): assumes timeout is never None — an upstream interceptor
        # appears to default it (timeout or 60.0). Confirm before using this
        # interceptor standalone; a None timeout would raise a TypeError below.
        expiration_interval = client_call_details.timeout
        start_time = loop.time()
        deadline = start_time + expiration_interval

        attempts = 0
        while True:
            # Give each attempt only the time left until the overall deadline.
            remaining = deadline - loop.time()
            # Namedtuple methods start with an underscore to avoid conflicts and aren't actually private
            # noinspection PyProtectedMember
            call_details = client_call_details._replace(timeout=remaining)
            rpc_call = await continuation(call_details, request)
            try:
                # Return the result directly if success. GRPC will wrap it back into a UnaryUnaryCall
                return await rpc_call
            except CadenceError as e:
                err = e

            attempts += 1
            elapsed = loop.time() - start_time
            backoff = self._retry_policy.next_delay(attempts, elapsed, expiration_interval)
            if not is_retryable(err, client_call_details) or backoff is None:
                break

            await asyncio.sleep(backoff)

        # On policy expiration, return the most recent UnaryUnaryCall. It has the error we want
        return rpc_call
79+
80+
81+
82+
def is_retryable(err: CadenceError, call_details: ClientCallDetails) -> bool:
    """Decide whether a failed call should be retried.

    EntityNotExistsError from GetWorkflowExecutionHistory is retryable only
    when the request landed on the passive cluster (current != active),
    matching the behavior of the Go and Java clients. Everything else is
    retried purely on its status code.
    """
    # Handle requests to the passive side, matching the Go and Java Clients
    if isinstance(err, EntityNotExistsError) and call_details.method == GET_WORKFLOW_HISTORY:
        active = err.active_cluster
        current = err.current_cluster
        if active is None or current is None:
            return False
        return active != current

    return err.code in RETRYABLE_CODES

cadence/_internal/rpc/yarpc.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,9 @@
1-
import collections
21
from typing import Any, Callable
32

43
from grpc.aio import Metadata
54
from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails
65

76

8-
class _ClientCallDetails(
9-
collections.namedtuple(
10-
"_ClientCallDetails", ("method", "timeout", "metadata", "credentials", "wait_for_ready")
11-
),
12-
ClientCallDetails,
13-
):
14-
pass
15-
167
SERVICE_KEY = "rpc-service"
178
CALLER_KEY = "rpc-caller"
189
ENCODING_KEY = "rpc-encoding"
@@ -42,11 +33,6 @@ def _replace_details(self, client_call_details: ClientCallDetails) -> ClientCall
4233
else:
4334
metadata += self._metadata
4435

45-
return _ClientCallDetails(
46-
method=client_call_details.method,
47-
# YARPC seems to require a TTL value
48-
timeout=client_call_details.timeout or 60.0,
49-
metadata=metadata,
50-
credentials=client_call_details.credentials,
51-
wait_for_ready=client_call_details.wait_for_ready,
52-
)
36+
# Namedtuple methods start with an underscore to avoid conflicts and aren't actually private
37+
# noinspection PyProtectedMember
38+
return client_call_details._replace(metadata=metadata, timeout=client_call_details.timeout or 60.0)

cadence/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from grpc import ChannelCredentials, Compression
66

77
from cadence._internal.rpc.error import CadenceErrorInterceptor
8+
from cadence._internal.rpc.retry import RetryInterceptor
89
from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor
910
from cadence.api.v1.service_domain_pb2_grpc import DomainAPIStub
1011
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
@@ -91,8 +92,9 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
9192

9293
def _create_channel(options: ClientOptions) -> Channel:
9394
interceptors = list(options["interceptors"])
94-
interceptors.append(CadenceErrorInterceptor())
9595
interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"]))
96+
interceptors.append(RetryInterceptor())
97+
interceptors.append(CadenceErrorInterceptor())
9698

9799
if options["credentials"]:
98100
return secure_channel(options["target"], options["credentials"], options["channel_arguments"], options["compression"], interceptors)
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from concurrent import futures
2+
from typing import Tuple, Type
3+
4+
import pytest
5+
from google.protobuf import any_pb2
6+
from google.rpc import status_pb2, code_pb2
7+
from grpc import server
8+
from grpc.aio import insecure_channel
9+
from grpc_status.rpc_status import to_status
10+
11+
from cadence._internal.rpc.error import CadenceErrorInterceptor
12+
from cadence.api.v1 import error_pb2, service_workflow_pb2_grpc
13+
14+
from cadence._internal.rpc.retry import ExponentialRetryPolicy, RetryInterceptor
15+
from cadence.api.v1.service_workflow_pb2 import DescribeWorkflowExecutionResponse, \
16+
DescribeWorkflowExecutionRequest, GetWorkflowExecutionHistoryRequest
17+
from cadence.error import CadenceError, FeatureNotEnabledError, EntityNotExistsError
18+
19+
# Shared policy for the table-driven cases: 1s initial delay, doubling each
# attempt, capped at 10s per sleep, at most 6 attempts.
simple_policy = ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=2, max_interval=10, max_attempts=6)

@pytest.mark.parametrize(
    "policy,params,expected",
    [
        pytest.param(
            simple_policy, (1, 0.0, 100.0), 1, id="happy path"
        ),
        pytest.param(
            simple_policy, (2, 0.0, 100.0), 2, id="second attempt"
        ),
        pytest.param(
            simple_policy, (3, 0.0, 100.0), 4, id="third attempt"
        ),
        pytest.param(
            simple_policy, (5, 0.0, 100.0), 10, id="capped by max_interval"
        ),
        pytest.param(
            simple_policy, (6, 0.0, 100.0), None, id="out of attempts"
        ),
        pytest.param(
            simple_policy, (1, 100.0, 100.0), None, id="timeout"
        ),
        pytest.param(
            simple_policy, (1, 99.0, 100.0), None, id="backoff causes timeout"
        ),
        pytest.param(
            ExponentialRetryPolicy(initial_interval=1, backoff_coefficient=1, max_interval=10, max_attempts=0), (100, 0.0, 100.0), 1, id="unlimited retries"
        ),
    ]
)
def test_next_delay(policy: ExponentialRetryPolicy, params: Tuple[int, float, float], expected: float | None):
    # params is (attempts, elapsed, expiration); expected None means "give up".
    assert policy.next_delay(*params) == expected
52+
53+
54+
class FakeService(service_workflow_pb2_grpc.WorkflowAPIServicer):
    """In-process WorkflowAPI servicer whose behavior is driven by request.domain.

    counter records how many RPCs were received in total so tests can assert
    how many attempts the retry interceptor actually made.
    """

    def __init__(self) -> None:
        super().__init__()
        self.port = None  # bound ephemeral port, filled in by the fake_service fixture
        self.counter = 0  # total RPCs received across all methods

    # Retryable only because it's GetWorkflowExecutionHistory
    def GetWorkflowExecutionHistory(self, request: GetWorkflowExecutionHistoryRequest, context):
        self.counter += 1

        # Always fail with EntityNotExistsError. The request's domain doubles
        # as the current cluster, letting tests pick active ("active") vs passive.
        detail = any_pb2.Any()
        detail.Pack(error_pb2.EntityNotExistsError(current_cluster=request.domain, active_cluster="active"))
        status_proto = status_pb2.Status(
            code=code_pb2.NOT_FOUND,
            message="message",
            details=[detail],
        )
        context.abort_with_status(to_status(status_proto))
        # Unreachable

    # Not retryable
    def DescribeWorkflowExecution(self, request: DescribeWorkflowExecutionRequest, context):
        self.counter += 1

        if request.domain == "success":
            return DescribeWorkflowExecutionResponse()
        elif request.domain == "retryable":
            code = code_pb2.RESOURCE_EXHAUSTED
        elif request.domain == "maybe later":
            # Fail the first two attempts, then succeed on the third.
            if self.counter >= 3:
                return DescribeWorkflowExecutionResponse()

            code = code_pb2.RESOURCE_EXHAUSTED
        else:
            code = code_pb2.PERMISSION_DENIED

        detail = any_pb2.Any()
        detail.Pack(error_pb2.FeatureNotEnabledError(feature_flag="the flag"))
        status_proto = status_pb2.Status(
            code=code,
            message="message",
            details=[detail],
        )
        context.abort_with_status(to_status(status_proto))
        # Unreachable
101+
102+
@pytest.fixture(scope="module")
def fake_service():
    """Run FakeService on a real (sync) gRPC server for the whole module.

    Yields the servicer itself, with .port set to the bound ephemeral port.
    """
    fake = FakeService()
    sync_server = server(futures.ThreadPoolExecutor(max_workers=1))
    service_workflow_pb2_grpc.add_WorkflowAPIServicer_to_server(fake, sync_server)
    # Port 0 lets the OS pick a free port; remember it for client channels.
    fake.port = sync_server.add_insecure_port("[::]:0")
    sync_server.start()
    yield fake
    sync_server.stop(grace=None)
112+
# Zero backoff so retries happen instantly in tests, capped at 10 attempts.
TEST_POLICY = ExponentialRetryPolicy(initial_interval=0, backoff_coefficient=0, max_interval=10, max_attempts=10)
113+
114+
@pytest.mark.usefixtures("fake_service")
@pytest.mark.parametrize(
    "case,expected_calls,expected_err",
    [
        pytest.param(
            "success", 1, None, id="happy path"
        ),
        pytest.param(
            "maybe later", 3, None, id="retries then success"
        ),
        pytest.param(
            "not retryable", 1, FeatureNotEnabledError, id="not retryable"
        ),
        pytest.param(
            "retryable", TEST_POLICY.max_attempts, FeatureNotEnabledError, id="retries exhausted"
        ),

    ]
)
@pytest.mark.asyncio
async def test_retryable_error(fake_service, case: str, expected_calls: int, expected_err: Type[CadenceError]):
    """DescribeWorkflowExecution is retried only on retryable codes, up to max_attempts."""
    # The fixture is module-scoped, so reset the shared call counter per test.
    fake_service.counter = 0
    async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel:
        stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel)
        if expected_err:
            with pytest.raises(expected_err):
                await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10)
        else:
            await stub.DescribeWorkflowExecution(DescribeWorkflowExecutionRequest(domain=case), timeout=10)

    # The servicer counts every attempt, including retries.
    assert fake_service.counter == expected_calls
145+
146+
@pytest.mark.usefixtures("fake_service")
@pytest.mark.parametrize(
    "case,expected_calls",
    [
        pytest.param(
            "active", 1, id="not retryable"
        ),
        pytest.param(
            "not active", TEST_POLICY.max_attempts, id="retries exhausted"
        ),

    ]
)
@pytest.mark.asyncio
async def test_workflow_history(fake_service, case: str, expected_calls: int):
    """EntityNotExistsError from GetWorkflowExecutionHistory is retried only
    when the request hit the passive cluster (current cluster != active)."""
    # The fixture is module-scoped, so reset the shared call counter per test.
    fake_service.counter = 0
    async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[RetryInterceptor(TEST_POLICY), CadenceErrorInterceptor()]) as channel:
        stub = service_workflow_pb2_grpc.WorkflowAPIStub(channel)
        # The fake always fails this RPC, so the final outcome is always the error.
        with pytest.raises(EntityNotExistsError):
            await stub.GetWorkflowExecutionHistory(GetWorkflowExecutionHistoryRequest(domain=case), timeout=10)

    assert fake_service.counter == expected_calls

0 commit comments

Comments
 (0)