Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions cadence/_internal/rpc/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from typing import Callable, Any, Optional, Generator, TypeVar

import grpc
from google.rpc.status_pb2 import Status # type: ignore
from grpc.aio import UnaryUnaryClientInterceptor, ClientCallDetails, AioRpcError, UnaryUnaryCall, Metadata
from grpc_status.rpc_status import from_call # type: ignore

from cadence.api.v1 import error_pb2
from cadence import error


RequestType = TypeVar("RequestType")
ResponseType = TypeVar("ResponseType")
DoneCallbackType = Callable[[Any], None]


# A UnaryUnaryCall is an awaitable type returned by GRPC's aio support.
# We need to take the UnaryUnaryCall we receive and return one that remaps the exception.
# It doesn't have any functions to compose operations together, so our only option is to wrap it.
# If the interceptor directly throws an exception other than AioRpcError it breaks GRPC
class CadenceErrorUnaryUnaryCall(UnaryUnaryCall[RequestType, ResponseType]):
    """A ``UnaryUnaryCall`` that delegates to another call and remaps its failure.

    Every method forwards to the wrapped call unchanged. The only behavioural
    difference is in ``__await__``: an ``AioRpcError`` raised while awaiting
    the wrapped call is converted with ``map_error`` and re-raised as the
    resulting ``CadenceError``.
    """

    def __init__(self, wrapped: UnaryUnaryCall[RequestType, ResponseType]):
        """Store the call whose result and metadata this object exposes."""
        super().__init__()
        self._wrapped = wrapped

    def __await__(self) -> Generator[Any, None, ResponseType]:
        """Await the wrapped call, translating ``AioRpcError`` into a Cadence error.

        Raises:
            error.CadenceError: (or a subclass) when the wrapped call fails
                with ``AioRpcError``; the original error is kept as
                ``__cause__``.
        """
        try:
            response = yield from self._wrapped.__await__()  # type: ResponseType
            return response
        except AioRpcError as e:
            # Chain explicitly so the original gRPC error is reported as the
            # direct cause of the mapped error.
            raise map_error(e) from e

    async def initial_metadata(self) -> Metadata:
        """Return the wrapped call's initial metadata."""
        return await self._wrapped.initial_metadata()

    async def trailing_metadata(self) -> Metadata:
        """Return the wrapped call's trailing metadata."""
        return await self._wrapped.trailing_metadata()

    async def code(self) -> grpc.StatusCode:
        """Return the wrapped call's status code."""
        return await self._wrapped.code()

    async def details(self) -> str:
        """Return the wrapped call's status details."""
        return await self._wrapped.details()  # type: ignore

    async def wait_for_connection(self) -> None:
        """Wait on the wrapped call's connection."""
        await self._wrapped.wait_for_connection()

    def cancelled(self) -> bool:
        """Report whether the wrapped call was cancelled."""
        return self._wrapped.cancelled()  # type: ignore

    def done(self) -> bool:
        """Report whether the wrapped call has finished."""
        return self._wrapped.done()  # type: ignore

    def time_remaining(self) -> Optional[float]:
        """Return the wrapped call's remaining time, if any."""
        return self._wrapped.time_remaining()  # type: ignore

    def cancel(self) -> bool:
        """Cancel the wrapped call and return its result."""
        return self._wrapped.cancel()  # type: ignore

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Register ``callback`` on the wrapped call."""
        self._wrapped.add_done_callback(callback)


class CadenceErrorInterceptor(UnaryUnaryClientInterceptor):
    """Client interceptor that wraps each unary-unary call in ``CadenceErrorUnaryUnaryCall``."""

    async def intercept_unary_unary(
        self,
        continuation: Callable[[ClientCallDetails, Any], Any],
        client_call_details: ClientCallDetails,
        request: Any
    ) -> Any:
        """Invoke the rest of the interceptor chain and wrap the call it returns.

        No error is raised from here; the wrapper performs the remapping
        when the returned call is awaited.
        """
        return CadenceErrorUnaryUnaryCall(
            await continuation(client_call_details, request)
        )




def map_error(e: AioRpcError) -> error.CadenceError:
    """Convert a gRPC ``AioRpcError`` into the matching ``CadenceError``.

    The rich status attached to the call is read with ``from_call``. Only the
    first entry of ``status.details`` is inspected; its protobuf type selects
    the ``CadenceError`` subclass, and typed fields are unpacked into the
    subclass's constructor arguments.

    Args:
        e: The error raised by the gRPC call.

    Returns:
        A specific ``CadenceError`` subclass when the first detail is a known
        Cadence error proto; otherwise a plain ``CadenceError`` carrying the
        call's details and status code.
    """
    message = e.details()
    code = e.code()

    try:
        status: Status | None = from_call(e)
    except ValueError:
        # from_call raises ValueError when the embedded rich status disagrees
        # with the call's own code/details. Don't let that hide the actual
        # RPC failure; report it as a generic CadenceError instead.
        status = None

    if not status or not status.details:
        return error.CadenceError(message, code)

    details = status.details[0]
    if details.Is(error_pb2.WorkflowExecutionAlreadyStartedError.DESCRIPTOR):
        already_started = error_pb2.WorkflowExecutionAlreadyStartedError()
        details.Unpack(already_started)
        return error.WorkflowExecutionAlreadyStartedError(
            message, code, already_started.start_request_id, already_started.run_id
        )
    if details.Is(error_pb2.EntityNotExistsError.DESCRIPTOR):
        not_exists = error_pb2.EntityNotExistsError()
        details.Unpack(not_exists)
        return error.EntityNotExistsError(
            message,
            code,
            not_exists.current_cluster,
            not_exists.active_cluster,
            # Copy the repeated field into a plain list.
            list(not_exists.active_clusters),
        )
    if details.Is(error_pb2.WorkflowExecutionAlreadyCompletedError.DESCRIPTOR):
        return error.WorkflowExecutionAlreadyCompletedError(message, code)
    if details.Is(error_pb2.DomainNotActiveError.DESCRIPTOR):
        not_active = error_pb2.DomainNotActiveError()
        details.Unpack(not_active)
        return error.DomainNotActiveError(
            message,
            code,
            not_active.domain,
            not_active.current_cluster,
            not_active.active_cluster,
            list(not_active.active_clusters),
        )
    if details.Is(error_pb2.ClientVersionNotSupportedError.DESCRIPTOR):
        not_supported = error_pb2.ClientVersionNotSupportedError()
        details.Unpack(not_supported)
        return error.ClientVersionNotSupportedError(
            message,
            code,
            not_supported.feature_version,
            not_supported.client_impl,
            not_supported.supported_versions,
        )
    if details.Is(error_pb2.FeatureNotEnabledError.DESCRIPTOR):
        not_enabled = error_pb2.FeatureNotEnabledError()
        details.Unpack(not_enabled)
        return error.FeatureNotEnabledError(message, code, not_enabled.feature_flag)
    if details.Is(error_pb2.CancellationAlreadyRequestedError.DESCRIPTOR):
        return error.CancellationAlreadyRequestedError(message, code)
    if details.Is(error_pb2.DomainAlreadyExistsError.DESCRIPTOR):
        return error.DomainAlreadyExistsError(message, code)
    if details.Is(error_pb2.LimitExceededError.DESCRIPTOR):
        return error.LimitExceededError(message, code)
    if details.Is(error_pb2.QueryFailedError.DESCRIPTOR):
        return error.QueryFailedError(message, code)
    if details.Is(error_pb2.ServiceBusyError.DESCRIPTOR):
        service_busy = error_pb2.ServiceBusyError()
        details.Unpack(service_busy)
        return error.ServiceBusyError(message, code, service_busy.reason)

    # Unrecognised detail type: fall back to the generic error.
    return error.CadenceError(message, code)

2 changes: 2 additions & 0 deletions cadence/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from grpc import ChannelCredentials, Compression

from cadence._internal.rpc.error import CadenceErrorInterceptor
from cadence._internal.rpc.yarpc import YarpcMetadataInterceptor
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel
Expand Down Expand Up @@ -75,6 +76,7 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:

def _create_channel(options: ClientOptions) -> Channel:
interceptors = list(options["interceptors"])
interceptors.append(CadenceErrorInterceptor())
interceptors.append(YarpcMetadataInterceptor(options["service_name"], options["caller_name"]))

if options["credentials"]:
Expand Down
65 changes: 65 additions & 0 deletions cadence/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import grpc


class CadenceError(Exception):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure whether we should inherit from RuntimeError or just Exception here, but it's OK to land it now and revisit as we move forward.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://docs.python.org/3/library/exceptions.html#Exception - "All user-defined exceptions should also be derived from this class.".

Simpler than Java, thankfully


def __init__(self, message: str, code: grpc.StatusCode, *args: object) -> None:
    """Initialise the error with its message and gRPC status code.

    Args:
        message: Error message; becomes ``args[0]``.
        code: gRPC status code; becomes ``args[1]`` and is also stored
            as ``self.code``.
        *args: Extra values appended to the exception's ``args`` tuple.
    """
    super().__init__(message, code, *args)
    self.code = code


class WorkflowExecutionAlreadyStartedError(CadenceError):
    """Cadence error carrying ``start_request_id`` and ``run_id``."""

    def __init__(
        self,
        message: str,
        code: grpc.StatusCode,
        start_request_id: str,
        run_id: str,
    ) -> None:
        """Store both identifiers as attributes and append them to ``args``."""
        self.start_request_id = start_request_id
        self.run_id = run_id
        super().__init__(message, code, start_request_id, run_id)

class EntityNotExistsError(CadenceError):
    """Cadence error carrying the current cluster, active cluster and active-cluster list."""

    def __init__(
        self,
        message: str,
        code: grpc.StatusCode,
        current_cluster: str,
        active_cluster: str,
        active_clusters: list[str],
    ) -> None:
        """Store the cluster fields as attributes and append them to ``args``."""
        self.current_cluster = current_cluster
        self.active_cluster = active_cluster
        self.active_clusters = active_clusters
        super().__init__(message, code, current_cluster, active_cluster, active_clusters)

class WorkflowExecutionAlreadyCompletedError(CadenceError):
    """Cadence error with no fields beyond the message and status code."""

class DomainNotActiveError(CadenceError):
    """Cadence error carrying the domain plus current/active cluster information."""

    def __init__(
        self,
        message: str,
        code: grpc.StatusCode,
        domain: str,
        current_cluster: str,
        active_cluster: str,
        active_clusters: list[str],
    ) -> None:
        """Store the domain and cluster fields as attributes and append them to ``args``."""
        self.domain = domain
        self.current_cluster = current_cluster
        self.active_cluster = active_cluster
        self.active_clusters = active_clusters
        super().__init__(message, code, domain, current_cluster, active_cluster, active_clusters)

class ClientVersionNotSupportedError(CadenceError):
    """Cadence error carrying the feature version, client implementation and supported versions."""

    def __init__(
        self,
        message: str,
        code: grpc.StatusCode,
        feature_version: str,
        client_impl: str,
        supported_versions: str,
    ) -> None:
        """Store the version fields as attributes and append them to ``args``."""
        self.feature_version = feature_version
        self.client_impl = client_impl
        self.supported_versions = supported_versions
        super().__init__(message, code, feature_version, client_impl, supported_versions)

class FeatureNotEnabledError(CadenceError):
    """Cadence error carrying the ``feature_flag`` value."""

    def __init__(
        self,
        message: str,
        code: grpc.StatusCode,
        feature_flag: str,
    ) -> None:
        """Store ``feature_flag`` as an attribute and append it to ``args``."""
        self.feature_flag = feature_flag
        super().__init__(message, code, feature_flag)

class CancellationAlreadyRequestedError(CadenceError):
    """Cadence error with no fields beyond the message and status code."""

class DomainAlreadyExistsError(CadenceError):
    """Cadence error with no fields beyond the message and status code."""

class LimitExceededError(CadenceError):
    """Cadence error with no fields beyond the message and status code."""

class QueryFailedError(CadenceError):
    """Cadence error with no fields beyond the message and status code."""

class ServiceBusyError(CadenceError):
    """Cadence error carrying the ``reason`` value."""

    def __init__(
        self,
        message: str,
        code: grpc.StatusCode,
        reason: str,
    ) -> None:
        """Store ``reason`` as an attribute and append it to ``args``."""
        self.reason = reason
        super().__init__(message, code, reason)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ classifiers = [
requires-python = ">=3.11,<3.14"
dependencies = [
"grpcio==1.71.2",
"grpcio-status>=1.71.2",
"msgspec>=0.19.0",
"protobuf==5.29.1",
"typing-extensions>=4.0.0",
Expand Down
122 changes: 122 additions & 0 deletions tests/cadence/_internal/rpc/test_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from concurrent import futures

import pytest
from google.protobuf import any_pb2
from google.rpc import code_pb2, status_pb2
from grpc import Status, StatusCode, server
from grpc.aio import insecure_channel
from grpc_status.rpc_status import to_status

from cadence._internal.rpc.error import CadenceErrorInterceptor
from cadence.api.v1 import error_pb2, service_meta_pb2_grpc
from cadence import error
from google.protobuf.message import Message

from cadence.api.v1.service_meta_pb2 import HealthRequest, HealthResponse
from cadence.error import CadenceError


class FakeService(service_meta_pb2_grpc.MetaAPIServicer):
    """MetaAPI servicer whose next ``Health`` call can be made to fail.

    Assign a status to ``status`` to have the next ``Health`` call abort with
    it; the status is cleared once used. ``port`` is filled in by the fixture
    that starts the server.
    """

    def __init__(self) -> None:
        super().__init__()
        self.port: int | None = None
        self.status: Status | None = None

    def Health(self, request, context):
        """Abort with the queued status if one is set; otherwise report healthy."""
        queued = self.status
        if not queued:
            return HealthResponse(ok=True)
        # Consume the status so only a single call fails.
        self.status = None
        context.abort_with_status(queued)
        return HealthResponse(ok=True)


@pytest.fixture(scope="module")
def fake_service():
    """Start a gRPC server hosting ``FakeService`` and yield the service.

    The server is bound to port 0 and the port number returned by
    ``add_insecure_port`` is recorded on ``service.port``. The server is
    stopped after the module's tests finish.
    """
    service = FakeService()
    executor = futures.ThreadPoolExecutor(max_workers=1)
    grpc_server = server(executor)
    service_meta_pb2_grpc.add_MetaAPIServicer_to_server(service, grpc_server)
    service.port = grpc_server.add_insecure_port("[::]:0")
    grpc_server.start()
    yield service
    grpc_server.stop(grace=None)

@pytest.mark.usefixtures("fake_service")
@pytest.mark.parametrize(
    "err,expected",
    [
        # Success path: the service is not primed with a status, so no error is expected.
        pytest.param(None, None,id="no error"),
        # Each case below packs one error proto into the status details and
        # expects the corresponding cadence.error exception with matching args.
        pytest.param(
            error_pb2.WorkflowExecutionAlreadyStartedError(start_request_id="start_request", run_id="run_id"),
            error.WorkflowExecutionAlreadyStartedError(message="message", code=StatusCode.INVALID_ARGUMENT, start_request_id="start_request", run_id="run_id"),
            id="WorkflowExecutionAlreadyStartedError"),
        pytest.param(
            error_pb2.EntityNotExistsError(current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]),
            error.EntityNotExistsError(message="message", code=StatusCode.INVALID_ARGUMENT, current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]),
            id="EntityNotExistsError"),
        pytest.param(
            error_pb2.WorkflowExecutionAlreadyCompletedError(),
            error.WorkflowExecutionAlreadyCompletedError(message="message", code=StatusCode.INVALID_ARGUMENT),
            id="WorkflowExecutionAlreadyCompletedError"),
        pytest.param(
            error_pb2.DomainNotActiveError(domain="domain", current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]),
            error.DomainNotActiveError(message="message", code=StatusCode.INVALID_ARGUMENT, domain="domain", current_cluster="current_cluster", active_cluster="active_cluster", active_clusters=["active_clusters"]),
            id="DomainNotActiveError"),
        pytest.param(
            error_pb2.ClientVersionNotSupportedError(feature_version="feature_version", client_impl="client_impl", supported_versions="supported_versions"),
            error.ClientVersionNotSupportedError(message="message", code=StatusCode.INVALID_ARGUMENT, feature_version="feature_version", client_impl="client_impl", supported_versions="supported_versions"),
            id="ClientVersionNotSupportedError"),
        pytest.param(
            error_pb2.FeatureNotEnabledError(feature_flag="feature_flag"),
            error.FeatureNotEnabledError(message="message", code=StatusCode.INVALID_ARGUMENT,feature_flag="feature_flag"),
            id="FeatureNotEnabledError"),
        pytest.param(
            error_pb2.CancellationAlreadyRequestedError(),
            error.CancellationAlreadyRequestedError(message="message", code=StatusCode.INVALID_ARGUMENT),
            id="CancellationAlreadyRequestedError"),
        pytest.param(
            error_pb2.DomainAlreadyExistsError(),
            error.DomainAlreadyExistsError(message="message", code=StatusCode.INVALID_ARGUMENT),
            id="DomainAlreadyExistsError"),
        pytest.param(
            error_pb2.LimitExceededError(),
            error.LimitExceededError(message="message", code=StatusCode.INVALID_ARGUMENT),
            id="LimitExceededError"),
        pytest.param(
            error_pb2.QueryFailedError(),
            error.QueryFailedError(message="message", code=StatusCode.INVALID_ARGUMENT),
            id="QueryFailedError"),
        pytest.param(
            error_pb2.ServiceBusyError(reason="reason"),
            error.ServiceBusyError(message="message", code=StatusCode.INVALID_ARGUMENT, reason="reason"),
            id="ServiceBusyError"),
        # A status with no details entry: expect the generic CadenceError
        # carrying the status message and code.
        pytest.param(
            to_status(status_pb2.Status(code=code_pb2.PERMISSION_DENIED, message="no permission")),
            error.CadenceError(message="no permission", code=StatusCode.PERMISSION_DENIED),
            id="unknown error type"),
    ]
)
@pytest.mark.asyncio
async def test_map_error(fake_service, err: Message | Status, expected: CadenceError):
    """Call ``Health`` through ``CadenceErrorInterceptor`` and check the raised error.

    ``err`` is either a protobuf message (packed into a status via
    ``details_to_status``) or a ready-made status; ``expected`` is the
    exception whose type and ``args`` the raised error must match, or
    ``None`` when the call should succeed.
    """
    async with insecure_channel(f"[::]:{fake_service.port}", interceptors=[CadenceErrorInterceptor()]) as channel:
        stub = service_meta_pb2_grpc.MetaAPIStub(channel)
        if expected is None:
            response = await stub.Health(HealthRequest(), timeout=1)
            assert response == HealthResponse(ok=True)
        else:
            # Prime the fake service so the next Health call aborts with this status.
            if isinstance(err, Message):
                fake_service.status = details_to_status(err)
            else:
                fake_service.status = err
            with pytest.raises(type(expected)) as exc_info:
                await stub.Health(HealthRequest(), timeout=1)
            # Compare args rather than the exception objects themselves.
            assert exc_info.value.args == expected.args

def details_to_status(message: Message) -> Status:
    """Build an INVALID_ARGUMENT status whose single detail is ``message``.

    The message is packed into an ``Any``, placed in a ``google.rpc.Status``
    with the text ``"message"``, and converted with ``to_status``.
    """
    packed = any_pb2.Any()
    packed.Pack(message)
    return to_status(
        status_pb2.Status(
            code=code_pb2.INVALID_ARGUMENT,
            message="message",
            details=[packed],
        )
    )

28 changes: 28 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.