131 changes: 130 additions & 1 deletion cadence/client.py
@@ -1,8 +1,12 @@
import os
import socket
from typing import TypedDict, Unpack, Any, cast
import uuid
from dataclasses import dataclass
from datetime import timedelta
from typing import TypedDict, Unpack, Any, cast, Union, Optional, Callable

from grpc import ChannelCredentials, Compression
from google.protobuf.duration_pb2 import Duration

from cadence._internal.rpc.error import CadenceErrorInterceptor
from cadence._internal.rpc.retry import RetryInterceptor
@@ -11,10 +15,31 @@
from cadence.api.v1.service_worker_pb2_grpc import WorkerAPIStub
from grpc.aio import Channel, ClientInterceptor, secure_channel, insecure_channel
from cadence.api.v1.service_workflow_pb2_grpc import WorkflowAPIStub
from cadence.api.v1.service_workflow_pb2 import StartWorkflowExecutionRequest, StartWorkflowExecutionResponse
from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution
from cadence.api.v1.tasklist_pb2 import TaskList
from cadence.data_converter import DataConverter, DefaultDataConverter
from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter



@dataclass
class StartWorkflowOptions:
    """Options for starting a workflow execution."""
    task_list: str
    execution_start_to_close_timeout: Optional[timedelta] = None
    task_start_to_close_timeout: Optional[timedelta] = None
    workflow_id: Optional[str] = None
    cron_schedule: Optional[str] = None

    def __post_init__(self):
        """Validate required fields after initialization."""
        if not self.task_list:
            raise ValueError("task_list is required")
        if not self.execution_start_to_close_timeout and not self.task_start_to_close_timeout:
            raise ValueError("either execution_start_to_close_timeout or task_start_to_close_timeout is required")


class ClientOptions(TypedDict, total=False):
    domain: str
    target: str
@@ -88,6 +113,110 @@ async def __aenter__(self) -> 'Client':
    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()

    async def _build_start_workflow_request(
        self,
        workflow: Union[str, Callable],
        args: tuple[Any, ...],
        options: StartWorkflowOptions
    ) -> StartWorkflowExecutionRequest:
        """Build a StartWorkflowExecutionRequest from parameters."""
        # Generate workflow ID if not provided
        workflow_id = options.workflow_id or str(uuid.uuid4())

        # Determine workflow type name
        if isinstance(workflow, str):
            workflow_type_name = workflow
        else:
            # For a callable, use its __name__ attribute, falling back to str(workflow)
            workflow_type_name = getattr(workflow, '__name__', str(workflow))

        # Encode input arguments
        input_payload = None
        if args:
            try:
                input_payload = await self.data_converter.to_data(list(args))
            except Exception as e:
                raise ValueError(f"Failed to encode workflow arguments: {e}") from e

        # Convert timedelta to protobuf Duration
        # (e.g. timedelta(minutes=5) becomes Duration(seconds=300) via FromTimedelta)
        execution_timeout = None
        if options.execution_start_to_close_timeout:
            execution_timeout = Duration()
            execution_timeout.FromTimedelta(options.execution_start_to_close_timeout)

        task_timeout = None
        if options.task_start_to_close_timeout:
            task_timeout = Duration()
            task_timeout.FromTimedelta(options.task_start_to_close_timeout)

        # Build the request
        request = StartWorkflowExecutionRequest(
            domain=self.domain,
            workflow_id=workflow_id,
            workflow_type=WorkflowType(name=workflow_type_name),
            task_list=TaskList(name=options.task_list),
            identity=self.identity,
            request_id=str(uuid.uuid4())
        )

        # Set optional fields
        if input_payload:
            request.input.CopyFrom(input_payload)
        if execution_timeout:
            request.execution_start_to_close_timeout.CopyFrom(execution_timeout)
        if task_timeout:
            request.task_start_to_close_timeout.CopyFrom(task_timeout)
        if options.cron_schedule:
            request.cron_schedule = options.cron_schedule

        return request

    async def start_workflow(
        self,
        workflow: Union[str, Callable],
        *args,
        **options_kwargs
    ) -> WorkflowExecution:
        """
        Start a workflow execution asynchronously.

        Args:
            workflow: Workflow function or workflow type name string
            *args: Arguments to pass to the workflow
            **options_kwargs: StartWorkflowOptions as keyword arguments

        Returns:
            WorkflowExecution with workflow_id and run_id

        Raises:
            ValueError: If required parameters are missing or invalid
            Exception: If the gRPC call fails
        """
        # Convert kwargs to StartWorkflowOptions
        options = StartWorkflowOptions(**options_kwargs)

        # Build the gRPC request
        request = await self._build_start_workflow_request(workflow, args, options)

        # Execute the gRPC call
        try:
            response: StartWorkflowExecutionResponse = await self.workflow_stub.StartWorkflowExecution(request)

            # Emit metrics if available
            if self.metrics_emitter:
                # TODO: Add workflow start metrics similar to Go client
                pass

            execution = WorkflowExecution()
            execution.workflow_id = request.workflow_id
            execution.run_id = response.run_id
            return execution
        except Exception as e:
            raise Exception(f"Failed to start workflow: {e}") from e




def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
    if "target" not in options:
        raise ValueError("target must be specified")
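For reference, a minimal usage sketch of the API added in this diff. It assumes the Client constructor accepts the ClientOptions keys shown above (target, domain); the endpoint, domain, workflow type name, task list, and argument values are placeholders, not taken from this change.

import asyncio
from datetime import timedelta

from cadence.client import Client


async def main() -> None:
    # Placeholder connection settings; Client(**ClientOptions) is assumed per the TypedDict above.
    async with Client(target="localhost:7833", domain="sample-domain") as client:
        # Keyword arguments are collected into StartWorkflowOptions by start_workflow.
        execution = await client.start_workflow(
            "SampleWorkflow",              # workflow type name (a callable's __name__ would also work)
            "order-42",                    # positional argument passed to the workflow
            task_list="sample-task-list",
            execution_start_to_close_timeout=timedelta(minutes=5),
        )
        print(execution.workflow_id, execution.run_id)


asyncio.run(main())

As documented in the start_workflow docstring, an empty task_list or missing timeouts raise ValueError before any gRPC call is made.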