1+ from concurrent import futures
2+ from typing import Tuple , Type
3+
4+ import pytest
5+ from google .protobuf import any_pb2
6+ from google .rpc import status_pb2 , code_pb2
7+ from grpc import server
8+ from grpc .aio import insecure_channel
9+ from grpc_status .rpc_status import to_status
10+
11+ from cadence ._internal .rpc .error import CadenceErrorInterceptor
12+ from cadence .api .v1 import error_pb2 , service_workflow_pb2_grpc
13+
14+ from cadence ._internal .rpc .retry import ExponentialRetryPolicy , RetryInterceptor
15+ from cadence .api .v1 .service_workflow_pb2 import DescribeWorkflowExecutionResponse , \
16+ DescribeWorkflowExecutionRequest , GetWorkflowExecutionHistoryRequest
17+ from cadence .error import CadenceError , FeatureNotEnabledError , EntityNotExistsError
18+
19+ simple_policy = ExponentialRetryPolicy (initial_interval = 1 , backoff_coefficient = 2 , max_interval = 10 , max_attempts = 6 )
20+
21+ @pytest .mark .parametrize (
22+ "policy,params,expected" ,
23+ [
24+ pytest .param (
25+ simple_policy , (1 , 0.0 , 100.0 ), 1 , id = "happy path"
26+ ),
27+ pytest .param (
28+ simple_policy , (2 , 0.0 , 100.0 ), 2 , id = "second attempt"
29+ ),
30+ pytest .param (
31+ simple_policy , (3 , 0.0 , 100.0 ), 4 , id = "third attempt"
32+ ),
33+ pytest .param (
34+ simple_policy , (5 , 0.0 , 100.0 ), 10 , id = "capped by max_interval"
35+ ),
36+ pytest .param (
37+ simple_policy , (6 , 0.0 , 100.0 ), None , id = "out of attempts"
38+ ),
39+ pytest .param (
40+ simple_policy , (1 , 100.0 , 100.0 ), None , id = "timeout"
41+ ),
42+ pytest .param (
43+ simple_policy , (1 , 99.0 , 100.0 ), None , id = "backoff causes timeout"
44+ ),
45+ pytest .param (
46+ ExponentialRetryPolicy (initial_interval = 1 , backoff_coefficient = 1 , max_interval = 10 , max_attempts = 0 ), (100 , 0.0 , 100.0 ), 1 , id = "unlimited retries"
47+ ),
48+ ]
49+ )
50+ def test_next_delay (policy : ExponentialRetryPolicy , params : Tuple [int , float , float ], expected : float | None ):
51+ assert policy .next_delay (* params ) == expected
52+
53+
54+ class FakeService (service_workflow_pb2_grpc .WorkflowAPIServicer ):
55+ def __init__ (self ) -> None :
56+ super ().__init__ ()
57+ self .port = None
58+ self .counter = 0
59+
60+ # Retryable only because it's GetWorkflowExecutionHistory
61+ def GetWorkflowExecutionHistory (self , request : GetWorkflowExecutionHistoryRequest , context ):
62+ self .counter += 1
63+
64+ detail = any_pb2 .Any ()
65+ detail .Pack (error_pb2 .EntityNotExistsError (current_cluster = request .domain , active_cluster = "active" ))
66+ status_proto = status_pb2 .Status (
67+ code = code_pb2 .NOT_FOUND ,
68+ message = "message" ,
69+ details = [detail ],
70+ )
71+ context .abort_with_status (to_status (status_proto ))
72+ # Unreachable
73+
74+
75+ # Not retryable
76+ def DescribeWorkflowExecution (self , request : DescribeWorkflowExecutionRequest , context ):
77+ self .counter += 1
78+
79+ if request .domain == "success" :
80+ return DescribeWorkflowExecutionResponse ()
81+ elif request .domain == "retryable" :
82+ code = code_pb2 .RESOURCE_EXHAUSTED
83+ elif request .domain == "maybe later" :
84+ if self .counter >= 3 :
85+ return DescribeWorkflowExecutionResponse ()
86+
87+ code = code_pb2 .RESOURCE_EXHAUSTED
88+ else :
89+ code = code_pb2 .PERMISSION_DENIED
90+
91+ detail = any_pb2 .Any ()
92+ detail .Pack (error_pb2 .FeatureNotEnabledError (feature_flag = "the flag" ))
93+ status_proto = status_pb2 .Status (
94+ code = code ,
95+ message = "message" ,
96+ details = [detail ],
97+ )
98+ context .abort_with_status (to_status (status_proto ))
99+ # Unreachable
100+
101+
102+ @pytest .fixture (scope = "module" )
103+ def fake_service ():
104+ fake = FakeService ()
105+ sync_server = server (futures .ThreadPoolExecutor (max_workers = 1 ))
106+ service_workflow_pb2_grpc .add_WorkflowAPIServicer_to_server (fake , sync_server )
107+ fake .port = sync_server .add_insecure_port ("[::]:0" )
108+ sync_server .start ()
109+ yield fake
110+ sync_server .stop (grace = None )
111+
112+ TEST_POLICY = ExponentialRetryPolicy (initial_interval = 0 , backoff_coefficient = 0 , max_interval = 10 , max_attempts = 10 )
113+
114+ @pytest .mark .usefixtures ("fake_service" )
115+ @pytest .mark .parametrize (
116+ "case,expected_calls,expected_err" ,
117+ [
118+ pytest .param (
119+ "success" , 1 , None , id = "happy path"
120+ ),
121+ pytest .param (
122+ "maybe later" , 3 , None , id = "retries then success"
123+ ),
124+ pytest .param (
125+ "not retryable" , 1 , FeatureNotEnabledError , id = "not retryable"
126+ ),
127+ pytest .param (
128+ "retryable" , TEST_POLICY .max_attempts , FeatureNotEnabledError , id = "retries exhausted"
129+ ),
130+
131+ ]
132+ )
133+ @pytest .mark .asyncio
134+ async def test_retryable_error (fake_service , case : str , expected_calls : int , expected_err : Type [CadenceError ]):
135+ fake_service .counter = 0
136+ async with insecure_channel (f"[::]:{ fake_service .port } " , interceptors = [RetryInterceptor (TEST_POLICY ), CadenceErrorInterceptor ()]) as channel :
137+ stub = service_workflow_pb2_grpc .WorkflowAPIStub (channel )
138+ if expected_err :
139+ with pytest .raises (expected_err ):
140+ await stub .DescribeWorkflowExecution (DescribeWorkflowExecutionRequest (domain = case ), timeout = 10 )
141+ else :
142+ await stub .DescribeWorkflowExecution (DescribeWorkflowExecutionRequest (domain = case ), timeout = 10 )
143+
144+ assert fake_service .counter == expected_calls
145+
146+ @pytest .mark .usefixtures ("fake_service" )
147+ @pytest .mark .parametrize (
148+ "case,expected_calls" ,
149+ [
150+ pytest .param (
151+ "active" , 1 , id = "not retryable"
152+ ),
153+ pytest .param (
154+ "not active" , TEST_POLICY .max_attempts , id = "retries exhausted"
155+ ),
156+
157+ ]
158+ )
159+ @pytest .mark .asyncio
160+ async def test_workflow_history (fake_service , case : str , expected_calls : int ):
161+ fake_service .counter = 0
162+ async with insecure_channel (f"[::]:{ fake_service .port } " , interceptors = [RetryInterceptor (TEST_POLICY ), CadenceErrorInterceptor ()]) as channel :
163+ stub = service_workflow_pb2_grpc .WorkflowAPIStub (channel )
164+ with pytest .raises (EntityNotExistsError ):
165+ await stub .GetWorkflowExecutionHistory (GetWorkflowExecutionHistoryRequest (domain = case ), timeout = 10 )
166+
167+ assert fake_service .counter == expected_calls
0 commit comments