-
Notifications
You must be signed in to change notification settings - Fork 5
Map AioRpcError to Cadence Error Types #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+339
−0
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| from typing import Callable, Any, Optional, Generator, TypeVar | ||
|
|
||
| import grpc | ||
| from google.rpc.status_pb2 import Status # type: ignore | ||
| from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails, AioRpcError, UnaryUnaryCall, Metadata | ||
| from grpc_status.rpc_status import from_call # type: ignore | ||
|
|
||
| from cadence.api.v1 import error_pb2 | ||
| from cadence import error | ||
|
|
||
|
|
||
| RequestType = TypeVar("RequestType") | ||
| ResponseType = TypeVar("ResponseType") | ||
| DoneCallbackType = Callable[[Any], None] | ||
|
|
||
|
|
||
| # A UnaryUnaryCall is an awaitable type returned by GRPC's aio support. | ||
| # We need to take the UnaryUnaryCall we receive and return one that remaps the exception. | ||
| # It doesn't have any functions to compose operations together, so our only option is to wrap it. | ||
| # If the interceptor directly throws an exception other than AioRpcError it breaks GRPC | ||
| class CadenceErrorUnaryUnaryCall(UnaryUnaryCall[RequestType, ResponseType]): | ||
|
|
||
| def __init__(self, wrapped: UnaryUnaryCall[RequestType, ResponseType]): | ||
| super().__init__() | ||
| self._wrapped = wrapped | ||
|
|
||
| def __await__(self) -> Generator[Any, None, ResponseType]: | ||
| try: | ||
| response = yield from self._wrapped.__await__() # type: ResponseType | ||
| return response | ||
| except AioRpcError as e: | ||
| raise map_error(e) | ||
|
|
||
| async def initial_metadata(self) -> Metadata: | ||
| return await self._wrapped.initial_metadata() | ||
|
|
||
| async def trailing_metadata(self) -> Metadata: | ||
| return await self._wrapped.trailing_metadata() | ||
|
|
||
| async def code(self) -> grpc.StatusCode: | ||
| return await self._wrapped.code() | ||
|
|
||
| async def details(self) -> str: | ||
| return await self._wrapped.details() # type: ignore | ||
|
|
||
| async def wait_for_connection(self) -> None: | ||
| await self._wrapped.wait_for_connection() | ||
|
|
||
| def cancelled(self) -> bool: | ||
| return self._wrapped.cancelled() # type: ignore | ||
|
|
||
| def done(self) -> bool: | ||
| return self._wrapped.done() # type: ignore | ||
|
|
||
| def time_remaining(self) -> Optional[float]: | ||
| return self._wrapped.time_remaining() # type: ignore | ||
|
|
||
| def cancel(self) -> bool: | ||
| return self._wrapped.cancel() # type: ignore | ||
|
|
||
| def add_done_callback(self, callback: DoneCallbackType) -> None: | ||
| self._wrapped.add_done_callback(callback) | ||
|
|
||
|
|
||
| class CadenceErrorInterceptor(UnaryUnaryClientInterceptor): | ||
|
|
||
| async def intercept_unary_unary( | ||
| self, | ||
| continuation: Callable[[ClientCallDetails, Any], Any], | ||
| client_call_details: ClientCallDetails, | ||
| request: Any | ||
| ) -> Any: | ||
| rpc_call = await continuation(client_call_details, request) | ||
| return CadenceErrorUnaryUnaryCall(rpc_call) | ||
|
|
||
|
|
||
|
|
||
|
|
||
| def map_error(e: AioRpcError) -> error.CadenceError: | ||
| status: Status | None = from_call(e) | ||
| if not status or not status.details: | ||
| return error.CadenceError(e.details(), e.code()) | ||
|
|
||
| details = status.details[0] | ||
| if details.Is(error_pb2.WorkflowExecutionAlreadyStartedError.DESCRIPTOR): | ||
| already_started = error_pb2.WorkflowExecutionAlreadyStartedError() | ||
| details.Unpack(already_started) | ||
| return error.WorkflowExecutionAlreadyStartedError(e.details(), e.code(), already_started.start_request_id, already_started.run_id) | ||
| elif details.Is(error_pb2.EntityNotExistsError.DESCRIPTOR): | ||
| not_exists = error_pb2.EntityNotExistsError() | ||
| details.Unpack(not_exists) | ||
| return error.EntityNotExistsError(e.details(), e.code(), not_exists.current_cluster, not_exists.active_cluster, list(not_exists.active_clusters)) | ||
| elif details.Is(error_pb2.WorkflowExecutionAlreadyCompletedError.DESCRIPTOR): | ||
| return error.WorkflowExecutionAlreadyCompletedError(e.details(), e.code()) | ||
| elif details.Is(error_pb2.DomainNotActiveError.DESCRIPTOR): | ||
| not_active = error_pb2.DomainNotActiveError() | ||
| details.Unpack(not_active) | ||
| return error.DomainNotActiveError(e.details(), e.code(), not_active.domain, not_active.current_cluster, not_active.active_cluster, list(not_active.active_clusters)) | ||
| elif details.Is(error_pb2.ClientVersionNotSupportedError.DESCRIPTOR): | ||
| not_supported = error_pb2.ClientVersionNotSupportedError() | ||
| details.Unpack(not_supported) | ||
| return error.ClientVersionNotSupportedError(e.details(), e.code(), not_supported.feature_version, not_supported.client_impl, not_supported.supported_versions) | ||
| elif details.Is(error_pb2.FeatureNotEnabledError.DESCRIPTOR): | ||
| not_enabled = error_pb2.FeatureNotEnabledError() | ||
| details.Unpack(not_enabled) | ||
| return error.FeatureNotEnabledError(e.details(), e.code(), not_enabled.feature_flag) | ||
| elif details.Is(error_pb2.CancellationAlreadyRequestedError.DESCRIPTOR): | ||
| return error.CancellationAlreadyRequestedError(e.details(), e.code()) | ||
| elif details.Is(error_pb2.DomainAlreadyExistsError.DESCRIPTOR): | ||
| return error.DomainAlreadyExistsError(e.details(), e.code()) | ||
| elif details.Is(error_pb2.LimitExceededError.DESCRIPTOR): | ||
| return error.LimitExceededError(e.details(), e.code()) | ||
| elif details.Is(error_pb2.QueryFailedError.DESCRIPTOR): | ||
| return error.QueryFailedError(e.details(), e.code()) | ||
| elif details.Is(error_pb2.ServiceBusyError.DESCRIPTOR): | ||
| service_busy = error_pb2.ServiceBusyError() | ||
| details.Unpack(service_busy) | ||
| return error.ServiceBusyError(e.details(), e.code(), service_busy.reason) | ||
| else: | ||
| return error.CadenceError(e.details(), e.code()) | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| import grpc | ||
|
|
||
|
|
||
| class CadenceError(Exception): | ||
|
|
||
| def __init__(self, message: str, code: grpc.StatusCode, *args): | ||
| super().__init__(message, code, *args) | ||
| self.code = code | ||
| pass | ||
|
|
||
|
|
||
| class WorkflowExecutionAlreadyStartedError(CadenceError): | ||
|
|
||
| def __init__(self, message: str, code: grpc.StatusCode, start_request_id: str, run_id: str) -> None: | ||
| super().__init__(message, code, start_request_id, run_id) | ||
| self.start_request_id = start_request_id | ||
| self.run_id = run_id | ||
|
|
||
| class EntityNotExistsError(CadenceError): | ||
|
|
||
| def __init__(self, message: str, code: grpc.StatusCode, current_cluster: str, active_cluster: str, active_clusters: list[str]) -> None: | ||
| super().__init__(message, code, current_cluster, active_cluster, active_clusters) | ||
| self.current_cluster = current_cluster | ||
| self.active_cluster = active_cluster | ||
| self.active_clusters = active_clusters | ||
|
|
||
| class WorkflowExecutionAlreadyCompletedError(CadenceError): | ||
| pass | ||
|
|
||
| class DomainNotActiveError(CadenceError): | ||
| def __init__(self, message: str, code: grpc.StatusCode, domain: str, current_cluster: str, active_cluster: str, active_clusters: list[str]) -> None: | ||
| super().__init__(message, code, domain, current_cluster, active_cluster, active_clusters) | ||
| self.domain = domain | ||
| self.current_cluster = current_cluster | ||
| self.active_cluster = active_cluster | ||
| self.active_clusters = active_clusters | ||
|
|
||
| class ClientVersionNotSupportedError(CadenceError): | ||
| def __init__(self, message: str, code: grpc.StatusCode, feature_version: str, client_impl: str, supported_versions: str) -> None: | ||
| super().__init__(message, code, feature_version, client_impl, supported_versions) | ||
| self.feature_version = feature_version | ||
| self.client_impl = client_impl | ||
| self.supported_versions = supported_versions | ||
|
|
||
| class FeatureNotEnabledError(CadenceError): | ||
| def __init__(self, message: str, code: grpc.StatusCode, feature_flag: str) -> None: | ||
| super().__init__(message, code, feature_flag) | ||
| self.feature_flag = feature_flag | ||
|
|
||
| class CancellationAlreadyRequestedError(CadenceError): | ||
| pass | ||
|
|
||
| class DomainAlreadyExistsError(CadenceError): | ||
| pass | ||
|
|
||
| class LimitExceededError(CadenceError): | ||
| pass | ||
|
|
||
| class QueryFailedError(CadenceError): | ||
| pass | ||
|
|
||
| class ServiceBusyError(CadenceError): | ||
| def __init__(self, message: str, code: grpc.StatusCode, reason: str) -> None: | ||
| super().__init__(message, code, reason) | ||
| self.reason = reason | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,122 @@ | ||
| from concurrent import futures | ||
|
|
||
| import pytest | ||
| from google.protobuf import any_pb2 | ||
| from google.rpc import code_pb2, status_pb2 | ||
| from grpc import Status, StatusCode, server | ||
| from grpc.aio import insecure_channel | ||
| from grpc_status.rpc_status import to_status | ||
|
|
||
| from cadence._internal.rpc.error import CadenceErrorInterceptor | ||
| from cadence.api.v1 import error_pb2, service_meta_pb2_grpc | ||
| from cadence import error | ||
| from google.protobuf.message import Message | ||
|
|
||
| from cadence.api.v1.service_meta_pb2 import HealthRequest, HealthResponse | ||
| from cadence.error import CadenceError | ||
|
|
||
|
|
||
| class FakeService(service_meta_pb2_grpc.MetaAPIServicer): | ||
| def __init__(self) -> None: | ||
| super().__init__() | ||
| self.status: Status | None = None | ||
| self.port: int | None = None | ||
|
|
||
| def Health(self, request, context): | ||
| if temp := self.status: | ||
| self.status = None | ||
| context.abort_with_status(temp) | ||
| return HealthResponse(ok=True) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def fake_service(): | ||
| fake = FakeService() | ||
| sync_server = server(futures.ThreadPoolExecutor(max_workers=1)) | ||
| service_meta_pb2_grpc.add_MetaAPIServicer_to_server(fake, sync_server) | ||
| fake.port = sync_server.add_insecure_port("[::]:0") | ||
| sync_server.start() | ||
| yield fake | ||
| sync_server.stop(grace=None) | ||
|
|
||
| @pytest.mark.usefixtures("fake_service") | ||
| @pytest.mark.parametrize( | ||
| "err,expected", | ||
| [ | ||
| pytest.param(None, None,id="no error"), | ||
| pytest.param( | ||
| error_pb2.WorkflowExecutionAlreadyStartedError(start_request_id="start_request", run_id="run_id"), | ||
| error.WorkflowExecutionAlreadyStartedError(message="message", code=StatusCode.INVALID_ARGUMENT, start_request_id="start_request", run_id="run_id"), | ||
| id="WorkflowExecutionAlreadyStartedError"), | ||
| pytest.param( | ||
| error_pb2.EntityNotExistsError(current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), | ||
| error.EntityNotExistsError(message="message", code=StatusCode.INVALID_ARGUMENT, current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), | ||
| id="EntityNotExistsError"), | ||
| pytest.param( | ||
| error_pb2.WorkflowExecutionAlreadyCompletedError(), | ||
| error.WorkflowExecutionAlreadyCompletedError(message="message", code=StatusCode.INVALID_ARGUMENT), | ||
| id="WorkflowExecutionAlreadyCompletedError"), | ||
| pytest.param( | ||
| error_pb2.DomainNotActiveError(domain="domain", current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), | ||
| error.DomainNotActiveError(message="message", code=StatusCode.INVALID_ARGUMENT, domain="domain", current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]), | ||
| id="DomainNotActiveError"), | ||
| pytest.param( | ||
| error_pb2.ClientVersionNotSupportedError(feature_version="feature_version", client_impl="client_impl", supported_versions="supported_versions"), | ||
| error.ClientVersionNotSupportedError(message="message", code=StatusCode.INVALID_ARGUMENT, feature_version="feature_version", client_impl="client_impl", supported_versions="supported_versions"), | ||
| id="ClientVersionNotSupportedError"), | ||
| pytest.param( | ||
| error_pb2.FeatureNotEnabledError(feature_flag="feature_flag"), | ||
| error.FeatureNotEnabledError(message="message", code=StatusCode.INVALID_ARGUMENT,feature_flag="feature_flag"), | ||
| id="FeatureNotEnabledError"), | ||
| pytest.param( | ||
| error_pb2.CancellationAlreadyRequestedError(), | ||
| error.CancellationAlreadyRequestedError(message="message", code=StatusCode.INVALID_ARGUMENT), | ||
| id="CancellationAlreadyRequestedError"), | ||
| pytest.param( | ||
| error_pb2.DomainAlreadyExistsError(), | ||
| error.DomainAlreadyExistsError(message="message", code=StatusCode.INVALID_ARGUMENT), | ||
| id="DomainAlreadyExistsError"), | ||
| pytest.param( | ||
| error_pb2.LimitExceededError(), | ||
| error.LimitExceededError(message="message", code=StatusCode.INVALID_ARGUMENT), | ||
| id="LimitExceededError"), | ||
| pytest.param( | ||
| error_pb2.QueryFailedError(), | ||
| error.QueryFailedError(message="message", code=StatusCode.INVALID_ARGUMENT), | ||
| id="QueryFailedError"), | ||
| pytest.param( | ||
| error_pb2.ServiceBusyError(reason="reason"), | ||
| error.ServiceBusyError(message="message", code=StatusCode.INVALID_ARGUMENT, reason="reason"), | ||
| id="ServiceBusyError"), | ||
| pytest.param( | ||
| to_status(status_pb2.Status(code=code_pb2.PERMISSION_DENIED, message="no permission")), | ||
| error.CadenceError(message="no permission", code=StatusCode.PERMISSION_DENIED), | ||
| id="unknown error type"), | ||
| ] | ||
| ) | ||
| @pytest.mark.asyncio | ||
| async def test_map_error(fake_service, err: Message | Status, expected: CadenceError): | ||
| async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[CadenceErrorInterceptor()]) as channel: | ||
| stub = service_meta_pb2_grpc.MetaAPIStub(channel) | ||
| if expected is None: | ||
| response = await stub.Health(HealthRequest(), timeout=1) | ||
| assert response == HealthResponse(ok=True) | ||
| else: | ||
| if isinstance(err, Message): | ||
| fake_service.status = details_to_status(err) | ||
| else: | ||
| fake_service.status = err | ||
| with pytest.raises(type(expected)) as exc_info: | ||
| await stub.Health(HealthRequest(), timeout=1) | ||
| assert exc_info.value.args == expected.args | ||
|
|
||
| def details_to_status(message: Message) -> Status: | ||
| detail = any_pb2.Any() | ||
| detail.Pack(message) | ||
| status_proto = status_pb2.Status( | ||
| code=code_pb2.INVALID_ARGUMENT, | ||
| message="message", | ||
| details=[detail], | ||
| ) | ||
| return to_status(status_proto) | ||
|
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if we should inherit RuntimeError or just Exception here. But it's ok to just land it now as we move forward
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://docs.python.org/3/library/exceptions.html#Exception - "All user-defined exceptions should also be derived from this class.".
Simpler than Java, thankfully