Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 160 additions & 5 deletions cadence/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import socket
from typing import TypedDict, Unpack, Any, cast
import uuid
from datetime import timedelta
from typing import TypedDict, Unpack, Any, cast, Union, Callable

from grpc import ChannelCredentials, Compression
from google.protobuf.duration_pb2 import Duration

from cadence._internal.rpc.error import CadenceErrorInterceptor
from cadence._internal.rpc.retry import RetryInterceptor
Expand All @@ -11,10 +14,47 @@
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel
from cadence.api.v1.service_workflow_pb2_grpc import WorkflowAPIStub
from cadence.api.v1.service_workflow_pb2 import (
StartWorkflowExecutionRequest,
StartWorkflowExecutionResponse,
)
from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution
from cadence.api.v1.tasklist_pb2 import TaskList
from cadence.data_converter import DataConverter, DefaultDataConverter
from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter


class StartWorkflowOptions(TypedDict, total=False):
"""Options for starting a workflow execution."""

task_list: str
execution_start_to_close_timeout: timedelta
workflow_id: str
task_start_to_close_timeout: timedelta
cron_schedule: str


def _validate_and_apply_defaults(options: StartWorkflowOptions) -> StartWorkflowOptions:
"""Validate required fields and apply defaults to StartWorkflowOptions."""
if not options.get("task_list"):
raise ValueError("task_list is required")

execution_timeout = options.get("execution_start_to_close_timeout")
if not execution_timeout:
raise ValueError("execution_start_to_close_timeout is required")
if execution_timeout <= timedelta(0):
raise ValueError("execution_start_to_close_timeout must be greater than 0")

# Apply default for task_start_to_close_timeout if not provided (matching Go/Java clients)
task_timeout = options.get("task_start_to_close_timeout")
if task_timeout is None:
options["task_start_to_close_timeout"] = timedelta(seconds=10)
elif task_timeout <= timedelta(0):
raise ValueError("task_start_to_close_timeout must be greater than 0")

return options


class ClientOptions(TypedDict, total=False):
domain: str
target: str
Expand All @@ -28,6 +68,7 @@ class ClientOptions(TypedDict, total=False):
metrics_emitter: MetricsEmitter
interceptors: list[ClientInterceptor]


_DEFAULT_OPTIONS: ClientOptions = {
"data_converter": DefaultDataConverter(),
"identity": f"{os.getpid()}@{socket.gethostname()}",
Expand All @@ -40,6 +81,7 @@ class ClientOptions(TypedDict, total=False):
"interceptors": [],
}


class Client:
def __init__(self, **kwargs: Unpack[ClientOptions]) -> None:
self._options = _validate_and_copy_defaults(ClientOptions(**kwargs))
Expand Down Expand Up @@ -82,12 +124,112 @@ async def ready(self) -> None:
    async def close(self) -> None:
        """Close the underlying gRPC channel, releasing its resources."""
        await self._channel.close()

async def __aenter__(self) -> 'Client':
async def __aenter__(self) -> "Client":
return self

async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.close()

async def _build_start_workflow_request(
self,
workflow: Union[str, Callable],
args: tuple[Any, ...],
options: StartWorkflowOptions,
) -> StartWorkflowExecutionRequest:
"""Build a StartWorkflowExecutionRequest from parameters."""
# Generate workflow ID if not provided
workflow_id = options.get("workflow_id") or str(uuid.uuid4())

# Determine workflow type name
if isinstance(workflow, str):
workflow_type_name = workflow
else:
# For callable, use function name or __name__ attribute
workflow_type_name = getattr(workflow, "__name__", str(workflow))

# Encode input arguments
input_payload = None
if args:
try:
input_payload = await self.data_converter.to_data(list(args))
except Exception as e:
raise ValueError(f"Failed to encode workflow arguments: {e}")

# Convert timedelta to protobuf Duration
execution_timeout = Duration()
execution_timeout.FromTimedelta(options["execution_start_to_close_timeout"])

task_timeout = Duration()
task_timeout.FromTimedelta(options["task_start_to_close_timeout"])

# Build the request
request = StartWorkflowExecutionRequest(
domain=self.domain,
workflow_id=workflow_id,
workflow_type=WorkflowType(name=workflow_type_name),
task_list=TaskList(name=options["task_list"]),
identity=self.identity,
request_id=str(uuid.uuid4()),
)

# Set required timeout fields
request.execution_start_to_close_timeout.CopyFrom(execution_timeout)
request.task_start_to_close_timeout.CopyFrom(task_timeout)

# Set optional fields
if input_payload:
request.input.CopyFrom(input_payload)
if options.get("cron_schedule"):
request.cron_schedule = options["cron_schedule"]

return request

async def start_workflow(
self,
workflow: Union[str, Callable],
*args,
**options_kwargs: Unpack[StartWorkflowOptions],
) -> WorkflowExecution:
"""
Start a workflow execution asynchronously.

Args:
workflow: Workflow function or workflow type name string
*args: Arguments to pass to the workflow
**options_kwargs: StartWorkflowOptions as keyword arguments

Returns:
WorkflowExecution with workflow_id and run_id

Raises:
ValueError: If required parameters are missing or invalid
Exception: If the gRPC call fails
"""
# Convert kwargs to StartWorkflowOptions and validate
options = _validate_and_apply_defaults(StartWorkflowOptions(**options_kwargs))

# Build the gRPC request
request = await self._build_start_workflow_request(workflow, args, options)

# Execute the gRPC call
try:
response: StartWorkflowExecutionResponse = (
await self.workflow_stub.StartWorkflowExecution(request)
)

# Emit metrics if available
if self.metrics_emitter:
# TODO: Add workflow start metrics similar to Go client
pass

execution = WorkflowExecution()
execution.workflow_id = request.workflow_id
execution.run_id = response.run_id
return execution
except Exception:
raise


def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
if "target" not in options:
raise ValueError("target must be specified")
Expand All @@ -105,11 +247,24 @@ def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:

def _create_channel(options: ClientOptions) -> Channel:
    """Create the gRPC channel (secure or insecure) described by options.

    The caller-supplied interceptors run first, followed by the YARPC
    metadata, retry, and Cadence error interceptors in that order.

    NOTE: the pasted diff interleaved the old single-line calls with their
    reformatted replacements; this is the post-change version only.

    Args:
        options: Validated client options (defaults already applied).

    Returns:
        A secure channel when credentials are set, otherwise an insecure one.
    """
    # Copy so the caller's interceptor list is not mutated.
    interceptors = list(options["interceptors"])
    interceptors.append(
        YarpcMetadataInterceptor(options["service_name"], options["caller_name"])
    )
    interceptors.append(RetryInterceptor())
    interceptors.append(CadenceErrorInterceptor())

    if options["credentials"]:
        return secure_channel(
            options["target"],
            options["credentials"],
            options["channel_arguments"],
            options["compression"],
            interceptors,
        )
    else:
        return insecure_channel(
            options["target"],
            options["channel_arguments"],
            options["compression"],
            interceptors,
        )
Loading