diff --git a/dapr_agents/workflow/decorators/messaging.py b/dapr_agents/workflow/decorators/messaging.py index 8187eb58..524d5bd0 100644 --- a/dapr_agents/workflow/decorators/messaging.py +++ b/dapr_agents/workflow/decorators/messaging.py @@ -1,11 +1,21 @@ import logging +import warnings from copy import deepcopy from typing import Any, Callable, Optional, get_type_hints + from dapr_agents.workflow.utils.core import is_valid_routable_model from dapr_agents.workflow.utils.messaging import extract_message_models logger = logging.getLogger(__name__) +_MESSAGE_ROUTER_DEPRECATION_MESSAGE = ( + "@message_router (legacy version from dapr_agents.workflow.decorators.messaging) " + "is deprecated and will be removed in a future release. " + "Please migrate to the updated decorator in " + "`dapr_agents.workflow.decorators.routers`, which supports " + "Union types, forward references, and explicit Dapr workflow integration." +) + def message_router( func: Optional[Callable[..., Any]] = None, @@ -16,7 +26,8 @@ def message_router( broadcast: bool = False, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ - Decorator for registering message handlers by inspecting type hints on the 'message' argument. + [DEPRECATED] Legacy decorator for registering message handlers by inspecting type hints + on the 'message' argument. This decorator: - Extracts the expected message model type from function annotations. @@ -36,6 +47,12 @@ def message_router( """ def decorator(f: Callable[..., Any]) -> Callable[..., Any]: + warnings.warn( + _MESSAGE_ROUTER_DEPRECATION_MESSAGE, + DeprecationWarning, + stacklevel=2, + ) + is_workflow = hasattr(f, "_is_workflow") workflow_name = getattr(f, "_workflow_name", None) @@ -56,7 +73,9 @@ def decorator(f: Callable[..., Any]) -> Callable[..., Any]: ) logger.debug( - f"@message_router: '{f.__name__}' => models {[m.__name__ for m in message_models]}" + "@message_router (legacy): '%s' => models %s", + f.__name__, + [m.__name__ for m in message_models], ) # Attach metadata for later registration diff --git a/dapr_agents/workflow/decorators/routers.py b/dapr_agents/workflow/decorators/routers.py new file mode 100644 index 00000000..3b803cbd --- /dev/null +++ b/dapr_agents/workflow/decorators/routers.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import inspect +import logging +from copy import deepcopy +from typing import ( + Any, + Callable, + Optional, + get_type_hints, +) + +from dapr_agents.workflow.utils.core import is_supported_model +from dapr_agents.workflow.utils.routers import extract_message_models + +logger = logging.getLogger(__name__) + + +def message_router( + func: Optional[Callable[..., Any]] = None, + *, + pubsub: Optional[str] = None, + topic: Optional[str] = None, + dead_letter_topic: Optional[str] = None, + broadcast: bool = False, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """ + Decorate a message handler with routing metadata. + + The handler must accept a parameter named `message`. Its type hint defines the + expected payload model(s), e.g.: + + @message_router(pubsub="pubsub", topic="orders") + def on_order(message: OrderCreated): ... + + @message_router(pubsub="pubsub", topic="events") + def on_event(message: Union[Foo, Bar]): ... + + Args: + func: (optional) bare-decorator form support. + pubsub: Name of the Dapr pub/sub component (required when used with args). + topic: Topic name to subscribe to (required when used with args). + dead_letter_topic: Optional dead-letter topic (defaults to f"{topic}_DEAD"). 
+ broadcast: Optional flag you can use downstream for fan-out semantics. + + Returns: + The original function tagged with `_message_router_data`. + """ + + def decorator(f: Callable[..., Any]) -> Callable[..., Any]: + # Validate required kwargs only when decorator is used with args + if pubsub is None or topic is None: + raise ValueError( + "`pubsub` and `topic` are required when using @message_router with arguments." + ) + + sig = inspect.signature(f) + if "message" not in sig.parameters: + raise ValueError(f"'{f.__name__}' must have a 'message' parameter.") + + # Resolve forward refs under PEP 563 / future annotations + try: + hints = get_type_hints(f, globalns=f.__globals__) + except Exception: + logger.debug( + "Failed to fully resolve type hints for %s", f.__name__, exc_info=True + ) + hints = getattr(f, "__annotations__", {}) or {} + + raw_hint = hints.get("message") + if raw_hint is None: + raise TypeError( + f"'{f.__name__}' must type-hint the 'message' parameter " + "(e.g., 'message: MyModel' or 'message: Union[A, B]')" + ) + + models = extract_message_models(raw_hint) + if not models: + raise TypeError( + f"Unsupported or unresolved message type for '{f.__name__}': {raw_hint!r}" + ) + + # Optional early validation of supported schema kinds + for m in models: + if not is_supported_model(m): + raise TypeError(f"Unsupported model type in '{f.__name__}': {m!r}") + + data = { + "pubsub": pubsub, + "topic": topic, + "dead_letter_topic": dead_letter_topic + or (f"{topic}_DEAD" if topic else None), + "is_broadcast": broadcast, + "message_schemas": models, # list[type] + "message_types": [m.__name__ for m in models], # list[str] + } + + # Attach metadata; deepcopy for defensive isolation + setattr(f, "_is_message_handler", True) + setattr(f, "_message_router_data", deepcopy(data)) + + logger.debug( + "@message_router: '%s' => models %s (topic=%s, pubsub=%s, broadcast=%s)", + f.__name__, + [m.__name__ for m in models], + topic, + pubsub, + broadcast, + ) + return f + + # Support both @message_router(...) and bare @message_router usage + return decorator if func is None else decorator(func) diff --git a/dapr_agents/workflow/utils/registration.py b/dapr_agents/workflow/utils/registration.py new file mode 100644 index 00000000..9a0f2dce --- /dev/null +++ b/dapr_agents/workflow/utils/registration.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import asyncio +import inspect +import logging +from typing import Any, Callable, Iterable, List, Optional, Type + +from dapr.clients import DaprClient +from dapr.clients.grpc._response import TopicEventResponse +from dapr.common.pubsub.subscription import SubscriptionMessage + +from dapr_agents.workflow.utils.messaging import ( + extract_cloudevent_data, + validate_message_model, +) + +logger = logging.getLogger(__name__) + + +def register_message_handlers( + targets: Iterable[Any], + dapr_client: DaprClient, + *, + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> List[Callable[[], None]]: + """Discover and subscribe handlers decorated with `@message_router`. + + Scans each target: + - If the target itself is a decorated function (has `_message_router_data`), it is registered. + - If the target is an object, all its attributes are scanned for decorated callables. + + Subscriptions use Dapr's streaming API (`subscribe_with_handler`) which invokes your handler + on a background thread. This function returns a list of "closer" callables. Invoking a closer + will unsubscribe the corresponding handler. 
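+
+    Example (a minimal sketch; `my_handler` stands in for any function
+    decorated with `@message_router`):
+
+        with DaprClient() as client:
+            closers = register_message_handlers([my_handler], client)
+            ...  # serve until shutdown
+            for close in closers:
+                close()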
+
+    Args:
+        targets: Functions and/or instances to inspect for `_message_router_data`.
+        dapr_client: Active Dapr client used to create subscriptions.
+        loop: Event loop to await async handlers. If omitted, uses the running loop
+            or falls back to `asyncio.get_event_loop()`.
+
+    Returns:
+        A list of callables. Each callable, when invoked, closes the associated subscription.
+    """
+    # Resolve loop strategy once up front.
+    if loop is None:
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = asyncio.get_event_loop()
+
+    closers: List[Callable[[], None]] = []
+
+    def _iter_handlers(obj: Any):
+        """Yield (owner, fn) pairs for decorated handlers on `obj`.
+
+        If `obj` is itself a decorated function, yield (None, obj).
+        If `obj` is an instance, scan its attributes for decorated callables.
+        """
+        meta = getattr(obj, "_message_router_data", None)
+        if callable(obj) and meta:
+            yield None, obj
+            return
+
+        for name in dir(obj):
+            fn = getattr(obj, name)
+            if callable(fn) and getattr(fn, "_message_router_data", None):
+                yield obj, fn
+
+    for target in targets:
+        for owner, handler in _iter_handlers(target):
+            meta = getattr(handler, "_message_router_data")
+            schemas: List[Type[Any]] = meta.get("message_schemas") or []
+
+            # Bind method to instance if needed (descriptor protocol).
+            bound = (
+                handler if owner is None else handler.__get__(owner, owner.__class__)
+            )
+
+            async def _invoke(
+                bound_handler: Callable[..., Any],
+                parsed: Any,
+            ) -> TopicEventResponse:
+                """Invoke the user handler (sync or async) and normalize the result."""
+                result = bound_handler(parsed)
+                if inspect.iscoroutine(result):
+                    result = await result
+                if isinstance(result, TopicEventResponse):
+                    return result
+                # Treat any truthy/None return as success unless user explicitly returns a response.
+                return TopicEventResponse("success")
+
+            def _make_handler(
+                bound_handler: Callable[..., Any],
+                # Default arguments bind the *current* iteration's values at
+                # definition time. A plain closure would late-bind `schemas`
+                # and `meta`, so every subscription would end up validating
+                # against the last handler's schemas.
+                schemas: List[Type[Any]] = schemas,
+                meta: dict = meta,
+            ) -> Callable[[SubscriptionMessage], TopicEventResponse]:
+                """Create a Dapr-compatible handler for a single decorated function."""
+
+                def handler_fn(message: SubscriptionMessage) -> TopicEventResponse:
+                    try:
+                        # 1) Extract payload + CloudEvent metadata (bytes/str/dict are also supported by the extractor)
+                        event_data, metadata = extract_cloudevent_data(message)
+
+                        # 2) Validate against the first matching schema (or dict as fallback)
+                        parsed = None
+                        for model in schemas or [dict]:
+                            try:
+                                parsed = validate_message_model(model, event_data)
+                                break
+                            except Exception:
+                                # Try the next schema; log at debug for signal without noise.
+                                logger.debug(
+                                    "Schema %r did not match payload; trying next.",
+                                    model,
+                                    exc_info=True,
+                                )
+                                continue
+
+                        if parsed is None:
+                            # Permanent schema mismatch → drop (DLQ if configured by Dapr)
+                            logger.warning(
+                                "No matching schema for message on topic %r; dropping. Raw payload: %r",
+                                meta["topic"],
+                                event_data,
+                            )
+                            return TopicEventResponse("drop")
+
+                        # 3) Attach CE metadata for downstream consumers
+                        if isinstance(parsed, dict):
+                            parsed["_message_metadata"] = metadata
+                        else:
+                            setattr(parsed, "_message_metadata", metadata)
+
+                        # 4) Bridge worker thread → event loop
+                        if loop and loop.is_running():
+                            fut = asyncio.run_coroutine_threadsafe(
+                                _invoke(bound_handler, parsed), loop
+                            )
+                            return fut.result()
+                        return asyncio.run(_invoke(bound_handler, parsed))
+
+                    except Exception:
+                        # Transient failure (I/O, handler crash, etc.) 
→ retry + logger.exception("Message handler error; requesting retry.") + return TopicEventResponse("retry") + + return handler_fn + + close_fn = dapr_client.subscribe_with_handler( + pubsub_name=meta["pubsub"], + topic=meta["topic"], + handler_fn=_make_handler(bound), + dead_letter_topic=meta.get("dead_letter_topic"), + ) + closers.append(close_fn) + + return closers diff --git a/dapr_agents/workflow/utils/routers.py b/dapr_agents/workflow/utils/routers.py new file mode 100644 index 00000000..20a04ed2 --- /dev/null +++ b/dapr_agents/workflow/utils/routers.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import json +import logging +import types +from dataclasses import is_dataclass +from types import NoneType +from typing import Any, Optional, Tuple, Type, Union, get_args, get_origin + +from dapr.common.pubsub.subscription import SubscriptionMessage + +from dapr_agents.types.message import EventMessageMetadata +from dapr_agents.workflow.utils.core import is_pydantic_model, is_supported_model + +logger = logging.getLogger(__name__) + + +def extract_message_models(type_hint: Any) -> list[type]: + """Normalize a message type hint into a concrete list of classes. + + Supports: + - Single class: `MyMessage` → `[MyMessage]` + - Union: `Union[Foo, Bar]` or `Foo | Bar` → `[Foo, Bar]` + - Optional: `Optional[Foo]` (i.e., `Union[Foo, None]`) → `[Foo]` + + Notes: + - Forward refs should be resolved by the caller (e.g., via `typing.get_type_hints`). + - Non-class entries (e.g., `None`, `typing.Any`) are filtered out. + - Returns an empty list when the hint isn't a usable class or union of classes. + """ + if type_hint is None: + return [] + + origin = get_origin(type_hint) + if origin in (Union, types.UnionType): # handle both `Union[...]` and `A | B` + return [ + t for t in get_args(type_hint) if t is not NoneType and isinstance(t, type) + ] + + return [type_hint] if isinstance(type_hint, type) else [] + + +def _maybe_json_loads(payload: Any, content_type: Optional[str]) -> Any: + """ + Best-effort JSON parsing based on content type and payload shape. + + - If payload is `dict`/`list` → return as-is. + - If bytes/str and content-type hints JSON (or text looks like JSON) → parse to Python. + - Otherwise → return the original payload. + + This helper is intentionally forgiving; callers should validate downstream. + """ + try: + if isinstance(payload, (dict, list)): + return payload + + ct = (content_type or "").lower() + looks_json = "json" in ct + + if isinstance(payload, bytes): + text = payload.decode("utf-8", errors="strict") + if looks_json or (text and text[0] in "{["): + return json.loads(text) + return text + + if isinstance(payload, str): + if looks_json or (payload and payload[0] in "{["): + return json.loads(payload) + return payload + + return payload + except Exception: + logger.debug("JSON parsing failed; returning raw payload", exc_info=True) + return payload + + +def extract_cloudevent_data( + message: Union[SubscriptionMessage, dict, bytes, str], +) -> Tuple[dict, dict]: + """ + Extract CloudEvent metadata and payload (attempting JSON parsing when appropriate). + + Accepts: + - `SubscriptionMessage` (Dapr SDK) + - `dict` (raw CloudEvent envelope) + - `bytes`/`str` (data-only; metadata is synthesized) + + Returns: + (event_data, metadata) as dictionaries. `event_data` may be non-dict JSON + (e.g., list) if the payload is an array; callers expecting dicts should handle it. + + Raises: + ValueError: For unsupported `message` types. 
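Example (data-only bytes; envelope metadata is synthesized):
+
+        >>> data, meta = extract_cloudevent_data(b'{"topic": "AI Agents"}')
+        >>> data
+        {'topic': 'AI Agents'}
+        >>> meta["datacontenttype"]
+        'application/json'
+
+    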
+ """ + if isinstance(message, SubscriptionMessage): + content_type = message.data_content_type() + raw = message.data() + event_data = _maybe_json_loads(raw, content_type) + metadata = EventMessageMetadata( + id=message.id(), + datacontenttype=content_type, + pubsubname=message.pubsub_name(), + source=message.source(), + specversion=message.spec_version(), + time=None, # not always populated by SDK + topic=message.topic(), + traceid=None, + traceparent=None, + type=message.type(), + tracestate=None, + headers=message.extensions(), + ).model_dump() + + elif isinstance(message, dict): + content_type = message.get("datacontenttype") + raw = message.get("data", {}) + event_data = _maybe_json_loads(raw, content_type) + metadata = EventMessageMetadata( + id=message.get("id"), + datacontenttype=content_type, + pubsubname=message.get("pubsubname"), + source=message.get("source"), + specversion=message.get("specversion"), + time=message.get("time"), + topic=message.get("topic"), + traceid=message.get("traceid"), + traceparent=message.get("traceparent"), + type=message.get("type"), + tracestate=message.get("tracestate"), + headers=message.get("extensions", {}), + ).model_dump() + + elif isinstance(message, (bytes, str)): + # No CloudEvent envelope; treat payload as data-only and synthesize minimal metadata. + content_type = "application/json" + event_data = _maybe_json_loads(message, content_type) + metadata = EventMessageMetadata( + id=None, + datacontenttype=content_type, + pubsubname=None, + source=None, + specversion=None, + time=None, + topic=None, + traceid=None, + traceparent=None, + type=None, + tracestate=None, + headers={}, + ).model_dump() + + else: + raise ValueError(f"Unexpected message type: {type(message)!r}") + + if not isinstance(event_data, dict): + logger.debug( + "Event data is not a dict (type=%s); value=%r", type(event_data), event_data + ) + + return event_data, metadata + + +def validate_message_model(model: Type[Any], event_data: dict) -> Any: + """ + Validate and coerce `event_data` into `model`. + + Supports: + - dict: returns `event_data` unchanged + - dataclass: constructs the dataclass + - Pydantic v2 model: uses `model_validate` + + Raises: + TypeError: If the model is not a supported kind. + ValueError: If validation/construction fails. + """ + if not is_supported_model(model): + raise TypeError(f"Unsupported model type: {model!r}") + + try: + logger.info(f"Validating payload with model '{model.__name__}'...") + + if model is dict: + return event_data + if is_dataclass(model): + return model(**event_data) + if is_pydantic_model(model): + return model.model_validate(event_data) + raise TypeError(f"Unsupported model type: {model!r}") + except Exception as e: + logger.error(f"Message validation failed for model '{model.__name__}': {e}") + raise ValueError(f"Message validation failed: {e}") + + +def parse_cloudevent( + message: Union[SubscriptionMessage, dict, bytes, str], + model: Optional[Type[Any]] = None, +) -> Tuple[Any, dict]: + """ + Parse a CloudEvent-like input and validate its payload against ``model``. + + Args: + message (Union[SubscriptionMessage, dict, bytes, str]): Incoming message; can be a Dapr ``SubscriptionMessage``, a raw + CloudEvent ``dict``, or bare ``bytes``/``str`` payloads. + model (Optional[Type[Any]]): Schema for payload validation (required). + + Returns: + Tuple[Any, dict]: A tuple containing the validated message and its metadata. + + Raises: + ValueError: If no model is provided or validation fails. 
+ """ + try: + event_data, metadata = extract_cloudevent_data(message) + + if model is None: + raise ValueError("Message validation failed: No model provided.") + + validated_message = validate_message_model(model, event_data) + + logger.info("Message successfully parsed and validated") + logger.debug(f"Data: {validated_message}") + logger.debug(f"metadata: {metadata}") + + return validated_message, metadata + + except Exception as e: + logger.error(f"Failed to parse CloudEvent: {e}", exc_info=True) + raise ValueError(f"Invalid CloudEvent: {str(e)}") diff --git a/quickstarts/04-message-router-workflow/README.md b/quickstarts/04-message-router-workflow/README.md new file mode 100644 index 00000000..5ffd81b5 --- /dev/null +++ b/quickstarts/04-message-router-workflow/README.md @@ -0,0 +1,539 @@ +# Message Router Workflow (Pub/Sub → Workflow) + +This quickstart shows how to trigger a Dapr Workflow from a Pub/Sub message using a lightweight `@message_router` decorator. Messages are validated at the edge (Pydantic), then your handler schedules a native Dapr workflow. Activities use the `@llm_activity` decorator to offload work to an LLM. + +You’ll run two processes: + +* **App**: subscribes to a topic, routes messages, and runs the workflow runtime +* **Client**: publishes a test message to that topic + +## Prerequisites + +- Python 3.10 (recommended) +- pip package manager +- OpenAI API key +- Dapr CLI and Docker installed + +## Environment Setup + +```bash +# Create a virtual environment +python3.10 -m venv .venv + +# Activate the virtual environment +# On Windows: +.venv\Scripts\activate +# On macOS/Linux: +source .venv/bin/activate + +# Install dependencies +pip install -r requirements.txt +``` + +## Configuration + +The quickstart includes an OpenAI component configuration in the `components` directory. You have two options to configure your API key: + +### Option 1: Using Environment Variables (Recommended) + +1. Create a `.env` file in the project root and add your OpenAI API key: + +```env +OPENAI_API_KEY=your_api_key_here +``` + +2. When running the examples with Dapr, use the helper script to resolve environment variables: + +```bash +# Get the environment variables from the .env file: +export $(grep -v '^#' ../../.env | xargs) + +# Create a temporary resources folder with resolved environment variables +temp_resources_folder=$(../resolve_env_templates.py ./components) + +# Run your dapr command with the temporary resources +dapr run --app-id dapr-agent-wf --resources-path $temp_resources_folder -- python sequential_workflow.py + +# Clean up when done +rm -rf $temp_resources_folder +``` + +> The temporary resources folder will be automatically deleted when the Dapr sidecar is stopped or when the computer is restarted. + +### Option 2: Direct Component Configuration + +You can directly update the `key` in [components/openai.yaml](components/openai.yaml): +```yaml +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: openai +spec: + type: conversation.openai + metadata: + - name: key + value: "YOUR_OPENAI_API_KEY" +``` + +Replace `YOUR_OPENAI_API_KEY` with your actual OpenAI API key. + +> Many LLM providers are compatible with OpenAI's API (DeepSeek, Google AI, etc.) and can be used with this component by configuring the appropriate parameters. Dapr also has [native support](https://docs.dapr.io/reference/components-reference/supported-conversation/) for other providers like Google AI, Anthropic, Mistral, DeepSeek, etc. 
+### Additional Components
+
+Make sure Dapr is initialized on your system:
+
+```bash
+dapr init
+```
+
+The quickstart also includes the other Dapr components it needs in the `components` directory. For example, the workflow state store is defined in `statestore.yaml`:
+
+```yaml
+apiVersion: dapr.io/v1alpha1
+kind: Component
+metadata:
+  name: workflowstatestore
+spec:
+  type: state.redis
+  version: v1
+  metadata:
+  - name: redisHost
+    value: localhost:6379
+  - name: actorStateStore
+    value: "true"
+```
+
+## Project layout
+
+```text
+04-message-router-workflow/
+├─ components/         # Dapr components (pubsub, conversation, workflow state)
+├─ app.py              # Starts WorkflowRuntime + subscribes message routers
+├─ handlers.py         # @message_router handler (validates + schedules workflow)
+├─ workflow.py         # workflow & activities (decorated with @llm_activity)
+└─ message_client.py   # publishes a test message to the topic
+```
+
+## Files
+
+### workflow.py
+
+```python
+from __future__ import annotations
+
+from dapr.ext.workflow import DaprWorkflowContext
+from dotenv import load_dotenv
+
+from dapr_agents.llm.dapr import DaprChatClient
+from dapr_agents.workflow.decorators import llm_activity
+
+load_dotenv()
+
+# Initialize the LLM client and workflow runtime
+llm = DaprChatClient(component_name="openai")
+
+
+def blog_workflow(ctx: DaprWorkflowContext, wf_input: dict) -> str:
+    """
+    Workflow input must be JSON-serializable. We accept a dict like:
+        {"topic": ""}
+    """
+    topic = wf_input["topic"]
+    outline = yield ctx.call_activity(create_outline, input={"topic": topic})
+    post = yield ctx.call_activity(write_post, input={"outline": outline})
+    return post
+
+
+@llm_activity(
+    prompt="Create a short outline about {topic}. Output 3-5 bullet points.",
+    llm=llm,
+)
+async def create_outline(ctx, topic: str) -> str:
+    # Implemented by the decorator; body can be empty.
+    pass
+
+
+@llm_activity(
+    prompt="Write a short blog post following this outline:\n{outline}",
+    llm=llm,
+)
+async def write_post(ctx, outline: str) -> str:
+    # Implemented by the decorator; body can be empty.
+    pass
+```
+
+### handlers.py
+
+```python
+from __future__ import annotations
+
+import logging
+
+import dapr.ext.workflow as wf
+from dapr.clients.grpc._response import TopicEventResponse
+from pydantic import BaseModel, Field
+
+from dapr_agents.workflow.decorators.routers import message_router
+
+logger = logging.getLogger(__name__)
+
+
+class StartBlogMessage(BaseModel):
+    topic: str = Field(min_length=1, description="Blog topic/title")
+
+
+# Import the workflow after defining models to avoid circular import surprises
+from workflow import blog_workflow  # noqa: E402
+
+
+@message_router(pubsub="messagepubsub", topic="blog.requests")
+def start_blog_workflow(message: StartBlogMessage) -> TopicEventResponse:
+    """
+    Triggered by pub/sub. Validates payload via Pydantic and schedules the workflow.
+    
+ """ + try: + client = wf.DaprWorkflowClient() + instance_id = client.schedule_new_workflow( + workflow=blog_workflow, + input=message.model_dump(), + ) + logger.info("Scheduled blog_workflow instance=%s topic=%s", instance_id, message.topic) + return TopicEventResponse("success") + except Exception as exc: # transient infra error → retry + logger.exception("Failed to schedule blog workflow: %s", exc) + return TopicEventResponse("retry") +``` + +### app.py + +```python +from __future__ import annotations + +import asyncio +import logging +import signal + +import dapr.ext.workflow as wf +from dapr.clients import DaprClient +from dotenv import load_dotenv +from handlers import start_blog_workflow +from workflow import ( + blog_workflow, + create_outline, + write_post, +) + +from dapr_agents.workflow.utils.registration import register_message_handlers + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def _wait_for_shutdown() -> None: + """Block until Ctrl+C or SIGTERM.""" + loop = asyncio.get_running_loop() + stop = asyncio.Event() + + def _set_stop(*_: object) -> None: + stop.set() + + try: + loop.add_signal_handler(signal.SIGINT, _set_stop) + loop.add_signal_handler(signal.SIGTERM, _set_stop) + except NotImplementedError: + # Windows fallback + signal.signal(signal.SIGINT, lambda *_: _set_stop()) + signal.signal(signal.SIGTERM, lambda *_: _set_stop()) + + await stop.wait() + + +async def main() -> None: + runtime = wf.WorkflowRuntime() + + runtime.register_workflow(blog_workflow) + runtime.register_activity(create_outline) + runtime.register_activity(write_post) + + runtime.start() + + try: + with DaprClient() as client: + # Wire streaming subscriptions for our router(s) + closers = register_message_handlers( + targets=[start_blog_workflow], + dapr_client=client, + ) + + try: + await _wait_for_shutdown() + finally: + for close in closers: + try: + close() + except Exception: + logger.exception("Error while closing subscription") + finally: + runtime.shutdown() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass +``` + +### message_client.py + +```python +from __future__ import annotations + +import asyncio +import json +import logging +import os +import random +import signal +import sys +from typing import Any, Dict + +from dapr.clients import DaprClient + +# --------------------------- +# Config via environment vars +# --------------------------- +PUBSUB_NAME = os.getenv("PUBSUB_NAME", "messagepubsub") +TOPIC_NAME = os.getenv("TOPIC_NAME", "blog.requests") +BLOG_TOPIC = os.getenv("BLOG_TOPIC", "AI Agents") # used when RAW_DATA is not provided +RAW_DATA = os.getenv("RAW_DATA") # if set, must be a JSON object (string) +CONTENT_TYPE = os.getenv("CONTENT_TYPE", "application/json") +CE_TYPE = os.getenv("CLOUDEVENT_TYPE") # optional CloudEvent 'type' metadata + +# Publish behavior +PUBLISH_ONCE = os.getenv("PUBLISH_ONCE", "true").lower() in {"1", "true", "yes"} +INTERVAL_SEC = float(os.getenv("INTERVAL_SEC", "0")) # used when PUBLISH_ONCE=false +MAX_ATTEMPTS = int(os.getenv("MAX_ATTEMPTS", "8")) +INITIAL_DELAY = float(os.getenv("INITIAL_DELAY", "0.5")) +BACKOFF_FACTOR = float(os.getenv("BACKOFF_FACTOR", "2.0")) +JITTER_FRAC = float(os.getenv("JITTER_FRAC", "0.2")) + +# Optional warmup (give sidecar/broker a moment) +STARTUP_DELAY = float(os.getenv("STARTUP_DELAY", "1.0")) + +logger = logging.getLogger("publisher") + + +async def _backoff_sleep(delay: float, jitter: float, factor: float) -> 
float:
+    """Sleep for ~delay seconds with ±jitter% randomness, then return the next delay."""
+    actual = max(0.0, delay * (1 + random.uniform(-jitter, jitter)))
+    if actual:
+        await asyncio.sleep(actual)
+    return delay * factor
+
+
+def _build_payload() -> Dict[str, Any]:
+    """
+    Build the JSON payload:
+      - if RAW_DATA is set → parse as JSON (must be an object)
+      - else → {"topic": BLOG_TOPIC}
+    """
+    if RAW_DATA:
+        try:
+            data = json.loads(RAW_DATA)
+        except Exception as exc:  # noqa: BLE001
+            raise ValueError(f"Invalid RAW_DATA JSON: {exc}") from exc
+        if not isinstance(data, dict):
+            raise ValueError("RAW_DATA must be a JSON object")
+        return data
+
+    return {"topic": BLOG_TOPIC}
+
+
+def _encode_payload(payload: Dict[str, Any]) -> bytes:
+    """Encode the payload as UTF-8 JSON bytes."""
+    return json.dumps(payload, ensure_ascii=False).encode("utf-8")
+
+
+async def publish_once(client: DaprClient, payload: Dict[str, Any]) -> None:
+    """Publish once with retries and exponential backoff."""
+    delay = INITIAL_DELAY
+    body = _encode_payload(payload)
+
+    for attempt in range(1, MAX_ATTEMPTS + 1):
+        try:
+            logger.info("publish attempt %d → %s/%s", attempt, PUBSUB_NAME, TOPIC_NAME)
+            client.publish_event(
+                pubsub_name=PUBSUB_NAME,
+                topic_name=TOPIC_NAME,
+                data=body,
+                data_content_type=CONTENT_TYPE,
+                publish_metadata=({"cloudevent.type": CE_TYPE} if CE_TYPE else None),
+            )
+            logger.info("published successfully")
+            return
+        except Exception as exc:  # noqa: BLE001
+            logger.warning("publish failed: %s", exc)
+            if attempt == MAX_ATTEMPTS:
+                raise
+            logger.info("retrying in ~%.2fs …", delay)
+            delay = await _backoff_sleep(delay, JITTER_FRAC, BACKOFF_FACTOR)
+
+
+async def main() -> int:
+    logging.basicConfig(level=logging.INFO)
+    stop_event = asyncio.Event()
+
+    # Signal-aware shutdown
+    loop = asyncio.get_running_loop()
+
+    def _stop(*_: object) -> None:
+        stop_event.set()
+
+    try:
+        loop.add_signal_handler(signal.SIGINT, _stop)
+        loop.add_signal_handler(signal.SIGTERM, _stop)
+    except NotImplementedError:
+        signal.signal(signal.SIGINT, lambda *_: _stop())
+        signal.signal(signal.SIGTERM, lambda *_: _stop())
+
+    # Optional warmup
+    if STARTUP_DELAY > 0:
+        await asyncio.sleep(STARTUP_DELAY)
+
+    payload = _build_payload()
+    logger.info("payload: %s", payload)
+
+    try:
+        with DaprClient() as client:
+            if PUBLISH_ONCE:
+                await publish_once(client, payload)
+                # brief wait so logs flush nicely under dapr
+                await asyncio.sleep(0.2)
+                return 0
+
+            # periodic mode
+            if INTERVAL_SEC <= 0:
+                logger.error("INTERVAL_SEC must be > 0 when PUBLISH_ONCE=false")
+                return 2
+
+            logger.info("starting periodic publisher every %.2fs", INTERVAL_SEC)
+            while not stop_event.is_set():
+                try:
+                    await publish_once(client, payload)
+                except Exception as exc:  # noqa: BLE001
+                    logger.error("giving up after %d attempts: %s", MAX_ATTEMPTS, exc)
+
+                # wait for next tick or shutdown
+                try:
+                    await asyncio.wait_for(stop_event.wait(), timeout=INTERVAL_SEC)
+                except asyncio.TimeoutError:
+                    pass
+
+            logger.info("shutdown requested; exiting")
+            return 0
+
+    except KeyboardInterrupt:
+        return 130
+    except Exception as exc:  # noqa: BLE001
+        logger.exception("fatal error: %s", exc)
+        return 1
+
+
+if __name__ == "__main__":
+    sys.exit(asyncio.run(main()))
+```
+
+## How it works (flow)
+
+* `message_client.py` publishes a JSON payload to `topic=blog.requests` on `pubsub=messagepubsub`.
+* `app.py` starts the Dapr Workflow runtime, registers `blog_workflow` and its activities, and subscribes your `@message_router` handler.
+* `handlers.py` receives and validates the message (`StartBlogMessage` via Pydantic), then calls `DaprWorkflowClient().schedule_new_workflow(...)`.
+* `workflow.py` runs `blog_workflow`, calling two LLM-backed activities (`create_outline`, `write_post`) via `@llm_activity`.
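+
+For reference, the client publishes a plain JSON body; Dapr wraps it in a CloudEvent envelope on the wire, and `extract_cloudevent_data` unwraps it again before validation. With the defaults above, the handler ends up validating a payload like this (the `topic` value comes from the client's `BLOG_TOPIC` env var):
+
+```json
+{
+  "topic": "AI Agents"
+}
+```
+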
+## Running
+
+Start the app (subscriber + workflow runtime):
+
+```bash
+rendered_components=$(../resolve_env_templates.py ./components)
+dapr run \
+  --app-id message-workflow \
+  --resources-path "$rendered_components" \
+  -- python app.py
+rm -rf "$rendered_components"
+```
+
+Publish a test message (publisher):
+
+```bash
+rendered_components=$(../resolve_env_templates.py ./components)
+dapr run \
+  --app-id message-workflow-client \
+  --resources-path "$rendered_components" \
+  -- python message_client.py
+rm -rf "$rendered_components"
+```
+
+## Publisher configuration (env vars)
+
+You can tweak `message_client.py` using environment variables:
+
+| Variable          | Default            | Description                                          |
+| ----------------- | ------------------ | ---------------------------------------------------- |
+| `PUBSUB_NAME`     | `messagepubsub`    | Pub/Sub component name                               |
+| `TOPIC_NAME`      | `blog.requests`    | Topic to publish to                                  |
+| `BLOG_TOPIC`      | `AI Agents`        | Fallback payload: `{"topic": BLOG_TOPIC}`            |
+| `RAW_DATA`        | *(unset)*          | JSON string that overrides payload (must be object)  |
+| `CONTENT_TYPE`    | `application/json` | Content type sent with the event                     |
+| `CLOUDEVENT_TYPE` | *(unset)*          | Optional `cloudevent.type` metadata                  |
+| `PUBLISH_ONCE`    | `true`             | If `false`, publish periodically                     |
+| `INTERVAL_SEC`    | `0`                | Period (seconds) when `PUBLISH_ONCE=false`           |
+| `MAX_ATTEMPTS`    | `8`                | Retry attempts per publish                           |
+| `INITIAL_DELAY`   | `0.5`              | Initial backoff seconds                              |
+| `BACKOFF_FACTOR`  | `2.0`              | Exponential backoff factor                           |
+| `JITTER_FRAC`     | `0.2`              | ± jitter applied to each delay                       |
+| `STARTUP_DELAY`   | `1.0`              | Sleep before first publish (sidecar warmup)          |
+
+### Additional Examples
+
+The commands below reuse `$rendered_components`; because that folder is removed at the end of each run above, re-create it first with `rendered_components=$(../resolve_env_templates.py ./components)`.
+
+Publish a custom topic via env:
+
+```bash
+BLOG_TOPIC="Serverless Agents" dapr run \
+  --app-id message-workflow-client \
+  --resources-path "$rendered_components" \
+  -- python message_client.py
+```
+
+Publish a raw JSON object:
+
+```bash
+RAW_DATA='{"topic":"Thoughtful Orchestration"}' dapr run \
+  --app-id message-workflow-client \
+  --resources-path "$rendered_components" \
+  -- python message_client.py
+```
+
+## Integration with Dapr
+
+Dapr Agents workflows leverage Dapr's core capabilities:
+
+- **Durability**: Workflows survive process restarts or crashes
+- **State Management**: Workflow state is persisted in a distributed state store
+- **Actor Model**: Tasks run as reliable, stateful actors within the workflow
+- **Event Handling**: Workflows can react to external events
+
+## Troubleshooting
+
+1. **Docker is Running**: Verify Docker is running (`docker ps`) and that containers for the `daprio/dapr`, `openzipkin/zipkin`, and `redis` images are up
+2. **Redis Connection**: Ensure Redis is running (automatically installed by Dapr)
+3. **Dapr Initialization**: If components aren't found, verify Dapr is initialized with `dapr init`
+4. 
**API Key**: Check your OpenAI API key if authentication fails diff --git a/quickstarts/04-message-router-workflow/app.py b/quickstarts/04-message-router-workflow/app.py new file mode 100644 index 00000000..3bd53e73 --- /dev/null +++ b/quickstarts/04-message-router-workflow/app.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import asyncio +import logging +import signal + +import dapr.ext.workflow as wf +from dapr.clients import DaprClient +from dotenv import load_dotenv +from handlers import start_blog_workflow +from workflow import ( + blog_workflow, + create_outline, + write_post, +) + +from dapr_agents.workflow.utils.registration import register_message_handlers + +load_dotenv() + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def _wait_for_shutdown() -> None: + """Block until Ctrl+C or SIGTERM.""" + loop = asyncio.get_running_loop() + stop = asyncio.Event() + + def _set_stop(*_: object) -> None: + stop.set() + + try: + loop.add_signal_handler(signal.SIGINT, _set_stop) + loop.add_signal_handler(signal.SIGTERM, _set_stop) + except NotImplementedError: + # Windows fallback + signal.signal(signal.SIGINT, lambda *_: _set_stop()) + signal.signal(signal.SIGTERM, lambda *_: _set_stop()) + + await stop.wait() + + +async def main() -> None: + runtime = wf.WorkflowRuntime() + + runtime.register_workflow(blog_workflow) + runtime.register_activity(create_outline) + runtime.register_activity(write_post) + + runtime.start() + + try: + with DaprClient() as client: + # Wire streaming subscriptions for our router(s) + closers = register_message_handlers( + targets=[start_blog_workflow], + dapr_client=client, + ) + + try: + await _wait_for_shutdown() + finally: + for close in closers: + try: + close() + except Exception: + logger.exception("Error while closing subscription") + finally: + runtime.shutdown() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/quickstarts/04-message-router-workflow/components/openai.yaml b/quickstarts/04-message-router-workflow/components/openai.yaml new file mode 100644 index 00000000..7c518fb2 --- /dev/null +++ b/quickstarts/04-message-router-workflow/components/openai.yaml @@ -0,0 +1,14 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: openai +spec: + type: conversation.openai + version: v1 + metadata: + - name: key + value: "{{OPENAI_API_KEY}}" + - name: model + value: gpt-5-mini + - name: temperature + value: 1 diff --git a/quickstarts/04-message-router-workflow/components/pubsub.yaml b/quickstarts/04-message-router-workflow/components/pubsub.yaml new file mode 100644 index 00000000..6bd05aca --- /dev/null +++ b/quickstarts/04-message-router-workflow/components/pubsub.yaml @@ -0,0 +1,10 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: messagepubsub +spec: + type: pubsub.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 diff --git a/quickstarts/04-message-router-workflow/components/statestore.yaml b/quickstarts/04-message-router-workflow/components/statestore.yaml new file mode 100644 index 00000000..2fc32cd0 --- /dev/null +++ b/quickstarts/04-message-router-workflow/components/statestore.yaml @@ -0,0 +1,12 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: workflowstatestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: actorStateStore + value: "true" \ No newline at end of file diff --git 
a/quickstarts/04-message-router-workflow/handlers.py b/quickstarts/04-message-router-workflow/handlers.py new file mode 100644 index 00000000..7a6e1baf --- /dev/null +++ b/quickstarts/04-message-router-workflow/handlers.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import logging + +import dapr.ext.workflow as wf +from dapr.clients.grpc._response import TopicEventResponse +from pydantic import BaseModel, Field + +from dapr_agents.workflow.decorators.routers import message_router + +logger = logging.getLogger(__name__) + + +class StartBlogMessage(BaseModel): + topic: str = Field(min_length=1, description="Blog topic/title") + + +# Import the workflow after defining models to avoid circular import surprises +from workflow import blog_workflow # noqa: E402 + + +@message_router(pubsub="messagepubsub", topic="blog.requests") +def start_blog_workflow(message: StartBlogMessage) -> TopicEventResponse: + """ + Triggered by pub/sub. Validates payload via Pydantic and schedules the workflow. + """ + try: + client = wf.DaprWorkflowClient() + instance_id = client.schedule_new_workflow( + workflow=blog_workflow, + input=message.model_dump(), + ) + logger.info( + "Scheduled blog_workflow instance=%s topic=%s", instance_id, message.topic + ) + return TopicEventResponse("success") + except Exception as exc: # transient infra error → retry + logger.exception("Failed to schedule blog workflow: %s", exc) + return TopicEventResponse("retry") diff --git a/quickstarts/04-message-router-workflow/message_client.py b/quickstarts/04-message-router-workflow/message_client.py new file mode 100644 index 00000000..fdbd39d7 --- /dev/null +++ b/quickstarts/04-message-router-workflow/message_client.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import os +import random +import signal +import sys +from typing import Any, Dict + +from dapr.clients import DaprClient + +# --------------------------- +# Config via environment vars +# --------------------------- +PUBSUB_NAME = os.getenv("PUBSUB_NAME", "messagepubsub") +TOPIC_NAME = os.getenv("TOPIC_NAME", "blog.requests") +BLOG_TOPIC = os.getenv("BLOG_TOPIC", "AI Agents") # used when RAW_DATA is not provided +RAW_DATA = os.getenv("RAW_DATA") # if set, must be a JSON object (string) +CONTENT_TYPE = os.getenv("CONTENT_TYPE", "application/json") +CE_TYPE = os.getenv("CLOUDEVENT_TYPE") # optional CloudEvent 'type' metadata + +# Publish behavior +PUBLISH_ONCE = os.getenv("PUBLISH_ONCE", "true").lower() in {"1", "true", "yes"} +INTERVAL_SEC = float(os.getenv("INTERVAL_SEC", "0")) # used when PUBLISH_ONCE=false +MAX_ATTEMPTS = int(os.getenv("MAX_ATTEMPTS", "8")) +INITIAL_DELAY = float(os.getenv("INITIAL_DELAY", "0.5")) +BACKOFF_FACTOR = float(os.getenv("BACKOFF_FACTOR", "2.0")) +JITTER_FRAC = float(os.getenv("JITTER_FRAC", "0.2")) + +# Optional warmup (give sidecar/broker a moment) +STARTUP_DELAY = float(os.getenv("STARTUP_DELAY", "1.0")) + +logger = logging.getLogger("publisher") + + +async def _backoff_sleep(delay: float, jitter: float, factor: float) -> float: + """Sleep for ~delay seconds with ±jitter% randomness, then return the next delay.""" + actual = max(0.0, delay * (1 + random.uniform(-jitter, jitter))) + if actual: + await asyncio.sleep(actual) + return delay * factor + + +def _build_payload() -> Dict[str, Any]: + """ + Build the JSON payload: + - if RAW_DATA is set → parse as JSON (must be an object) + - else → {"topic": BLOG_TOPIC} + """ + if RAW_DATA: + try: + data = json.loads(RAW_DATA) + except 
Exception as exc: # noqa: BLE001 + raise ValueError(f"Invalid RAW_DATA JSON: {exc}") from exc + if not isinstance(data, dict): + raise ValueError("RAW_DATA must be a JSON object") + return data + + return {"topic": BLOG_TOPIC} + + +def _encode_payload(payload: Dict[str, Any]) -> bytes: + """Encode the payload as UTF-8 JSON bytes.""" + return json.dumps(payload, ensure_ascii=False).encode("utf-8") + + +async def publish_once(client: DaprClient, payload: Dict[str, Any]) -> None: + """Publish once with retries and exponential backoff.""" + delay = INITIAL_DELAY + body = _encode_payload(payload) + + for attempt in range(1, MAX_ATTEMPTS + 1): + try: + logger.info("publish attempt %d → %s/%s", attempt, PUBSUB_NAME, TOPIC_NAME) + client.publish_event( + pubsub_name=PUBSUB_NAME, + topic_name=TOPIC_NAME, + data=body, + data_content_type=CONTENT_TYPE, + publish_metadata=({"cloudevent.type": CE_TYPE} if CE_TYPE else None), + ) + logger.info("published successfully") + return + except Exception as exc: # noqa: BLE001 + logger.warning("publish failed: %s", exc) + if attempt == MAX_ATTEMPTS: + raise + logger.info("retrying in ~%.2fs …", delay) + delay = await _backoff_sleep(delay, JITTER_FRAC, BACKOFF_FACTOR) + + +async def main() -> int: + logging.basicConfig(level=logging.INFO) + stop_event = asyncio.Event() + + # Signal-aware shutdown + loop = asyncio.get_running_loop() + + def _stop(*_: object) -> None: + stop_event.set() + + try: + loop.add_signal_handler(signal.SIGINT, _stop) + loop.add_signal_handler(signal.SIGTERM, _stop) + except NotImplementedError: + signal.signal(signal.SIGINT, lambda *_: _stop()) + signal.signal(signal.SIGTERM, lambda *_: _stop()) + + # Optional warmup + if STARTUP_DELAY > 0: + await asyncio.sleep(STARTUP_DELAY) + + payload = _build_payload() + logger.info("payload: %s", payload) + + try: + with DaprClient() as client: + if PUBLISH_ONCE: + await publish_once(client, payload) + # brief wait so logs flush nicely under dapr + await asyncio.sleep(0.2) + return 0 + + # periodic mode + if INTERVAL_SEC <= 0: + logger.error("INTERVAL_SEC must be > 0 when PUBLISH_ONCE=false") + return 2 + + logger.info("starting periodic publisher every %.2fs", INTERVAL_SEC) + while not stop_event.is_set(): + try: + await publish_once(client, payload) + except Exception as exc: # noqa: BLE001 + logger.error("giving up after %d attempts: %s", MAX_ATTEMPTS, exc) + + # wait for next tick or shutdown + try: + await asyncio.wait_for(stop_event.wait(), timeout=INTERVAL_SEC) + except asyncio.TimeoutError: + pass + + logger.info("shutdown requested; exiting") + return 0 + + except KeyboardInterrupt: + return 130 + except Exception as exc: # noqa: BLE001 + logger.exception("fatal error: %s", exc) + return 1 + + +if __name__ == "__main__": + sys.exit(asyncio.run(main())) diff --git a/quickstarts/04-message-router-workflow/workflow.py b/quickstarts/04-message-router-workflow/workflow.py new file mode 100644 index 00000000..86fcb5af --- /dev/null +++ b/quickstarts/04-message-router-workflow/workflow.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from dapr.ext.workflow import DaprWorkflowContext +from dotenv import load_dotenv + +from dapr_agents.llm.dapr import DaprChatClient +from dapr_agents.workflow.decorators import llm_activity + +load_dotenv() + +# Initialize the LLM client and workflow runtime +llm = DaprChatClient(component_name="openai") + + +def blog_workflow(ctx: DaprWorkflowContext, wf_input: dict) -> str: + """ + Workflow input must be JSON-serializable. 
We accept a dict like: + {"topic": ""} + """ + topic = wf_input["topic"] + outline = yield ctx.call_activity(create_outline, input={"topic": topic}) + post = yield ctx.call_activity(write_post, input={"outline": outline}) + return post + + +@llm_activity( + prompt="Create a short outline about {topic}. Output 3-5 bullet points.", + llm=llm, +) +async def create_outline(ctx, topic: str) -> str: + # Implemented by the decorator; body can be empty. + pass + + +@llm_activity( + prompt="Write a short blog post following this outline:\n{outline}", + llm=llm, +) +async def write_post(ctx, outline: str) -> str: + # Implemented by the decorator; body can be empty. + pass diff --git a/quickstarts/README.md b/quickstarts/README.md index f2d26034..347adbd1 100644 --- a/quickstarts/README.md +++ b/quickstarts/README.md @@ -119,6 +119,19 @@ This quickstart demonstrates how to design and run **agent-based workflows**, st [Go to Agent-based Workflow Patterns](./04-agent-based-workflows/) +### Message Router Workflow + +Learn how to trigger Dapr Workflows via Pub/Sub messages using the `@message_router` decorator. +This pattern connects event-driven systems with LLM-powered workflows, validating and routing structured messages to durable workflow executions. + +- **Event-Driven Orchestration**: Start workflows automatically when messages arrive on a topic +- **Edge Validation**: Enforce schema integrity with Pydantic before invoking workflows +- **Seamless Integration**: Combine Dapr Pub/Sub, Workflow Runtime, and LLM activities for resilient automation + +This quickstart demonstrates how to design a message-driven workflow where each published event triggers a workflow instance such as creating a blog post outline and draft powered by an LLM. + +[Go to Message Router Workflow](./04-message-router-workflow/) + ### Multi-Agent Workflows Advanced example of event-driven workflows with multiple autonomous agents: diff --git a/tests/workflow/test_message_router.py b/tests/workflow/test_message_router.py new file mode 100644 index 00000000..af876c42 --- /dev/null +++ b/tests/workflow/test_message_router.py @@ -0,0 +1,663 @@ +import pytest +from typing import Union, Optional +from dataclasses import dataclass +from unittest.mock import MagicMock +from pydantic import BaseModel, Field + +from dapr_agents.workflow.decorators.routers import message_router +from dapr_agents.workflow.utils.routers import ( + extract_message_models, + extract_cloudevent_data, + validate_message_model, + parse_cloudevent, +) +from dapr_agents.workflow.utils.registration import register_message_handlers + + +# Test Models +class OrderCreated(BaseModel): + """Test Pydantic model for order creation events.""" + + order_id: str = Field(..., description="Unique order identifier") + amount: float = Field(..., description="Order amount") + customer: str = Field(..., description="Customer name") + + +class OrderCancelled(BaseModel): + """Test Pydantic model for order cancellation events.""" + + order_id: str = Field(..., description="Order ID to cancel") + reason: str = Field(..., description="Cancellation reason") + + +@dataclass +class ShipmentCreated: + """Test dataclass for shipment events.""" + + shipment_id: str + order_id: str + carrier: str + + +# Tests for extract_message_models utility + + +def test_extract_message_models_single_class(): + """Test extracting a single model class.""" + models = extract_message_models(OrderCreated) + assert models == [OrderCreated] + + +def test_extract_message_models_union(): + """Test extracting 
models from Union type hint.""" + models = extract_message_models(Union[OrderCreated, OrderCancelled]) + assert set(models) == {OrderCreated, OrderCancelled} + + +def test_extract_message_models_optional(): + """Test extracting models from Optional type hint (filters out None).""" + models = extract_message_models(Optional[OrderCreated]) + assert models == [OrderCreated] + + +def test_extract_message_models_pipe_union(): + """Test extracting models from pipe union syntax (Python 3.10+).""" + # Note: This test requires Python 3.10+ for the | syntax + try: + hint = eval("OrderCreated | OrderCancelled") + models = extract_message_models(hint) + assert set(models) == {OrderCreated, OrderCancelled} + except SyntaxError: + pytest.skip("Python 3.10+ required for pipe union syntax") + + +def test_extract_message_models_none_input(): + """Test extracting models from None returns empty list.""" + models = extract_message_models(None) + assert models == [] + + +def test_extract_message_models_non_class(): + """Test extracting models from non-class type returns empty list.""" + models = extract_message_models("not a class") + assert models == [] + + +# Tests for message_router decorator + + +def test_message_router_requires_pubsub(): + """Test that message_router raises ValueError when pubsub is missing.""" + with pytest.raises( + ValueError, + match="`pubsub` and `topic` are required when using @message_router with arguments", + ): + + @message_router(topic="orders") + def handler(message: OrderCreated): + pass + + +def test_message_router_requires_topic(): + """Test that message_router raises ValueError when topic is missing.""" + with pytest.raises( + ValueError, + match="`pubsub` and `topic` are required when using @message_router with arguments", + ): + + @message_router(pubsub="messagepubsub") + def handler(message: OrderCreated): + pass + + +def test_message_router_requires_message_parameter(): + """Test that message_router raises ValueError when 'message' parameter is missing.""" + with pytest.raises(ValueError, match="must have a 'message' parameter"): + + @message_router(pubsub="messagepubsub", topic="orders") + def handler(data: OrderCreated): # Wrong parameter name + pass + + +def test_message_router_requires_type_hint(): + """Test that message_router raises TypeError when message parameter has no type hint.""" + with pytest.raises(TypeError, match="must type-hint the 'message' parameter"): + + @message_router(pubsub="messagepubsub", topic="orders") + def handler(message): # No type hint + pass + + +def test_message_router_unsupported_type(): + """Test that message_router raises TypeError for unsupported message types.""" + with pytest.raises(TypeError, match="Unsupported model type"): + + @message_router(pubsub="messagepubsub", topic="orders") + def handler(message: str): # str is not a supported model + pass + + +def test_message_router_basic_decoration(): + """Test basic message_router decoration with single model.""" + + @message_router(pubsub="messagepubsub", topic="orders.created") + def handle_order(message: OrderCreated): + return message.order_id + + # Check metadata attributes + assert hasattr(handle_order, "_is_message_handler") + assert handle_order._is_message_handler is True + assert hasattr(handle_order, "_message_router_data") + + data = handle_order._message_router_data + assert data["pubsub"] == "messagepubsub" + assert data["topic"] == "orders.created" + assert data["dead_letter_topic"] == "orders.created_DEAD" + assert data["is_broadcast"] is False + assert OrderCreated 
in data["message_schemas"] + assert "OrderCreated" in data["message_types"] + + +def test_message_router_with_dead_letter_topic(): + """Test message_router with custom dead letter topic.""" + + @message_router( + pubsub="messagepubsub", + topic="orders.created", + dead_letter_topic="orders.failed", + ) + def handle_order(message: OrderCreated): + pass + + data = handle_order._message_router_data + assert data["dead_letter_topic"] == "orders.failed" + + +def test_message_router_with_broadcast(): + """Test message_router with broadcast flag.""" + + @message_router(pubsub="messagepubsub", topic="notifications", broadcast=True) + def handle_notification(message: OrderCreated): + pass + + data = handle_notification._message_router_data + assert data["is_broadcast"] is True + + +def test_message_router_union_types(): + """Test message_router with Union of multiple message types.""" + + @message_router(pubsub="messagepubsub", topic="order.events") + def handle_order_event(message: Union[OrderCreated, OrderCancelled]): + pass + + data = handle_order_event._message_router_data + assert set(data["message_schemas"]) == {OrderCreated, OrderCancelled} + assert set(data["message_types"]) == {"OrderCreated", "OrderCancelled"} + + +def test_message_router_dataclass_model(): + """Test message_router with dataclass model.""" + + @message_router(pubsub="messagepubsub", topic="shipments") + def handle_shipment(message: ShipmentCreated): + pass + + data = handle_shipment._message_router_data + assert ShipmentCreated in data["message_schemas"] + assert "ShipmentCreated" in data["message_types"] + + +def test_message_router_preserves_function_metadata(): + """Test that message_router preserves function name and docstring.""" + + @message_router(pubsub="messagepubsub", topic="orders") + def my_handler(message: OrderCreated): + """Handler for order created events.""" + return "processed" + + assert my_handler.__name__ == "my_handler" + assert my_handler.__doc__ == "Handler for order created events." 
+ + +def test_message_router_function_still_callable(): + """Test that decorated function is still callable.""" + + @message_router(pubsub="messagepubsub", topic="orders") + def handle_order(message: OrderCreated): + return f"Processed order {message.order_id}" + + # Function should still be callable with the right arguments + test_order = OrderCreated(order_id="123", amount=99.99, customer="Alice") + result = handle_order(test_order) + assert result == "Processed order 123" + + +# Tests for validate_message_model utility + + +def test_validate_message_model_pydantic(): + """Test validating data against Pydantic model.""" + event_data = {"order_id": "123", "amount": 99.99, "customer": "Alice"} + result = validate_message_model(OrderCreated, event_data) + + assert isinstance(result, OrderCreated) + assert result.order_id == "123" + assert result.amount == 99.99 + assert result.customer == "Alice" + + +def test_validate_message_model_dataclass(): + """Test validating data against dataclass model.""" + event_data = {"shipment_id": "S123", "order_id": "O456", "carrier": "FedEx"} + result = validate_message_model(ShipmentCreated, event_data) + + assert isinstance(result, ShipmentCreated) + assert result.shipment_id == "S123" + assert result.order_id == "O456" + assert result.carrier == "FedEx" + + +def test_validate_message_model_dict(): + """Test validating data against dict model (passthrough).""" + event_data = {"key": "value", "number": 42} + result = validate_message_model(dict, event_data) + + assert result == event_data + assert isinstance(result, dict) + + +def test_validate_message_model_validation_error(): + """Test that validation errors are raised properly.""" + # Missing required field + event_data = {"order_id": "123"} # Missing 'amount' and 'customer' + + with pytest.raises(ValueError, match="Message validation failed"): + validate_message_model(OrderCreated, event_data) + + +def test_validate_message_model_unsupported_type(): + """Test that unsupported model types raise TypeError.""" + + class UnsupportedModel: + pass + + with pytest.raises(TypeError, match="Unsupported model type"): + validate_message_model(UnsupportedModel, {}) + + +# Tests for extract_cloudevent_data utility + + +def test_extract_cloudevent_data_from_dict(): + """Test extracting CloudEvent data from dict envelope.""" + message = { + "id": "event-123", + "source": "order-service", + "type": "order.created", + "datacontenttype": "application/json", + "data": {"order_id": "123", "amount": 99.99, "customer": "Alice"}, + "topic": "orders", + "pubsubname": "messagepubsub", + "specversion": "1.0", + } + + event_data, metadata = extract_cloudevent_data(message) + + assert event_data == {"order_id": "123", "amount": 99.99, "customer": "Alice"} + assert metadata["id"] == "event-123" + assert metadata["source"] == "order-service" + assert metadata["type"] == "order.created" + assert metadata["topic"] == "orders" + assert metadata["pubsubname"] == "messagepubsub" + + +def test_extract_cloudevent_data_from_dict_already_parsed(): + """Test extracting CloudEvent when data is already a dict.""" + message = { + "id": "event-123", + "data": {"key": "value"}, # Already a dict + "datacontenttype": "application/json", + } + + event_data, metadata = extract_cloudevent_data(message) + assert event_data == {"key": "value"} + + +def test_extract_cloudevent_data_from_bytes(): + """Test extracting CloudEvent data from bytes payload.""" + import json + + payload = json.dumps({"order_id": "123", "amount": 99.99}).encode("utf-8") + 
+
+
+# Tests for validate_message_model utility
+
+
+def test_validate_message_model_pydantic():
+    """Test validating data against Pydantic model."""
+    event_data = {"order_id": "123", "amount": 99.99, "customer": "Alice"}
+    result = validate_message_model(OrderCreated, event_data)
+
+    assert isinstance(result, OrderCreated)
+    assert result.order_id == "123"
+    assert result.amount == 99.99
+    assert result.customer == "Alice"
+
+
+def test_validate_message_model_dataclass():
+    """Test validating data against dataclass model."""
+    event_data = {"shipment_id": "S123", "order_id": "O456", "carrier": "FedEx"}
+    result = validate_message_model(ShipmentCreated, event_data)
+
+    assert isinstance(result, ShipmentCreated)
+    assert result.shipment_id == "S123"
+    assert result.order_id == "O456"
+    assert result.carrier == "FedEx"
+
+
+def test_validate_message_model_dict():
+    """Test validating data against dict model (passthrough)."""
+    event_data = {"key": "value", "number": 42}
+    result = validate_message_model(dict, event_data)
+
+    assert result == event_data
+    assert isinstance(result, dict)
+
+
+def test_validate_message_model_validation_error():
+    """Test that validation errors are raised properly."""
+    # Missing required field
+    event_data = {"order_id": "123"}  # Missing 'amount' and 'customer'
+
+    with pytest.raises(ValueError, match="Message validation failed"):
+        validate_message_model(OrderCreated, event_data)
+
+
+def test_validate_message_model_unsupported_type():
+    """Test that unsupported model types raise TypeError."""
+
+    class UnsupportedModel:
+        pass
+
+    with pytest.raises(TypeError, match="Unsupported model type"):
+        validate_message_model(UnsupportedModel, {})
+
+
+# Tests for extract_cloudevent_data utility
+
+
+def test_extract_cloudevent_data_from_dict():
+    """Test extracting CloudEvent data from dict envelope."""
+    message = {
+        "id": "event-123",
+        "source": "order-service",
+        "type": "order.created",
+        "datacontenttype": "application/json",
+        "data": {"order_id": "123", "amount": 99.99, "customer": "Alice"},
+        "topic": "orders",
+        "pubsubname": "messagepubsub",
+        "specversion": "1.0",
+    }
+
+    event_data, metadata = extract_cloudevent_data(message)
+
+    assert event_data == {"order_id": "123", "amount": 99.99, "customer": "Alice"}
+    assert metadata["id"] == "event-123"
+    assert metadata["source"] == "order-service"
+    assert metadata["type"] == "order.created"
+    assert metadata["topic"] == "orders"
+    assert metadata["pubsubname"] == "messagepubsub"
+
+
+def test_extract_cloudevent_data_from_dict_already_parsed():
+    """Test extracting CloudEvent when data is already a dict."""
+    message = {
+        "id": "event-123",
+        "data": {"key": "value"},  # Already a dict
+        "datacontenttype": "application/json",
+    }
+
+    event_data, metadata = extract_cloudevent_data(message)
+    assert event_data == {"key": "value"}
+
+
+def test_extract_cloudevent_data_from_bytes():
+    """Test extracting CloudEvent data from bytes payload."""
+    import json
+
+    payload = json.dumps({"order_id": "123", "amount": 99.99}).encode("utf-8")
+
+    event_data, metadata = extract_cloudevent_data(payload)
+
+    assert event_data == {"order_id": "123", "amount": 99.99}
+    assert metadata["datacontenttype"] == "application/json"
+
+
+def test_extract_cloudevent_data_from_str():
+    """Test extracting CloudEvent data from string payload."""
+    import json
+
+    payload = json.dumps({"order_id": "123", "amount": 99.99})
+    event_data, metadata = extract_cloudevent_data(payload)
+
+    assert event_data == {"order_id": "123", "amount": 99.99}
+    assert metadata["datacontenttype"] == "application/json"
+
+
+def test_extract_cloudevent_data_from_subscription_message():
+    """Test extracting CloudEvent from Dapr SubscriptionMessage."""
+    import json
+
+    # MagicMock is imported at module level; SubscriptionMessage exposes
+    # accessor methods, so each field is mocked as a method return value.
+    mock_message = MagicMock()
+    mock_message.id.return_value = "event-456"
+    mock_message.source.return_value = "test-service"
+    mock_message.type.return_value = "test.event"
+    mock_message.data_content_type.return_value = "application/json"
+    mock_message.data.return_value = json.dumps({"key": "value"}).encode("utf-8")
+    mock_message.topic.return_value = "test-topic"
+    mock_message.pubsub_name.return_value = "test-pubsub"
+    mock_message.spec_version.return_value = "1.0"
+    mock_message.extensions.return_value = {}
+
+    event_data, metadata = extract_cloudevent_data(mock_message)
+
+    assert event_data == {"key": "value"}
+    assert metadata["id"] == "event-456"
+    assert metadata["source"] == "test-service"
+    assert metadata["topic"] == "test-topic"
+
+
+def test_extract_cloudevent_data_unsupported_type():
+    """Test that unsupported message types raise ValueError."""
+    with pytest.raises(ValueError, match="Unexpected message type"):
+        extract_cloudevent_data(12345)  # int is not supported
+
+
+def test_extract_cloudevent_data_non_dict_data():
+    """Test handling non-dict event data (e.g., array)."""
+    message = {
+        "id": "event-123",
+        "data": [1, 2, 3],  # Array data
+        "datacontenttype": "application/json",
+    }
+
+    event_data, metadata = extract_cloudevent_data(message)
+    assert event_data == [1, 2, 3]
+    assert isinstance(event_data, list)
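+
+
+# Sketch: parse_cloudevent is expected to compose extract_cloudevent_data with
+# validate_message_model, so both paths should yield an equal model instance.
+# The equivalence (not the internal implementation) is what this asserts.
+def test_extract_then_validate_matches_parse_cloudevent():
+    """Test that extract + validate yields the same model as parse_cloudevent."""
+    message = {
+        "id": "event-321",
+        "data": {"order_id": "321", "amount": 10.0, "customer": "Eve"},
+    }
+
+    event_data, _ = extract_cloudevent_data(message)
+    validated = validate_message_model(OrderCreated, event_data)
+    parsed, _ = parse_cloudevent(message, model=OrderCreated)
+
+    assert isinstance(parsed, OrderCreated)
+    assert parsed == validated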
"event-123", "data": {"key": "value"}} + + with pytest.raises(ValueError, match="No model provided"): + parse_cloudevent(message, model=None) + + +def test_parse_cloudevent_validation_failure(): + """Test that validation failures are properly raised.""" + message = { + "id": "event-123", + "data": {"order_id": "123"}, # Missing required fields + } + + with pytest.raises(ValueError, match="Invalid CloudEvent"): + parse_cloudevent(message, model=OrderCreated) + + +def test_parse_cloudevent_from_bytes(): + """Test parsing CloudEvent from bytes payload.""" + import json + + payload = json.dumps( + {"order_id": "123", "amount": 99.99, "customer": "Bob"} + ).encode("utf-8") + + validated, metadata = parse_cloudevent(payload, model=OrderCreated) + + assert isinstance(validated, OrderCreated) + assert validated.order_id == "123" + assert validated.customer == "Bob" + + +# Integration tests + + +def test_message_router_end_to_end(): + """Test complete flow from decoration to execution with validation.""" + + results = [] + + @message_router(pubsub="messagepubsub", topic="orders.created") + def handle_order(message: OrderCreated): + results.append(message) + return "success" + + # Verify decoration + assert hasattr(handle_order, "_is_message_handler") + assert handle_order._is_message_handler is True + + # Simulate execution + test_order = OrderCreated(order_id="999", amount=199.99, customer="Charlie") + result = handle_order(test_order) + + assert result == "success" + assert len(results) == 1 + assert results[0].order_id == "999" + + +def test_message_router_multiple_handlers(): + """Test multiple handlers can be decorated independently.""" + + @message_router(pubsub="messagepubsub", topic="orders.created") + def handle_order_created(message: OrderCreated): + return "order_created" + + @message_router(pubsub="messagepubsub", topic="orders.cancelled") + def handle_order_cancelled(message: OrderCancelled): + return "order_cancelled" + + # Both should have independent metadata + assert handle_order_created._message_router_data["topic"] == "orders.created" + assert handle_order_cancelled._message_router_data["topic"] == "orders.cancelled" + assert ( + handle_order_created._message_router_data["message_schemas"][0] == OrderCreated + ) + assert ( + handle_order_cancelled._message_router_data["message_schemas"][0] + == OrderCancelled + ) + + +def test_message_router_with_class_method(): + """Test message_router can be used with class methods.""" + + class OrderHandler: + def __init__(self): + self.processed = [] + + @message_router(pubsub="messagepubsub", topic="orders") + def handle(self, message: OrderCreated): + self.processed.append(message.order_id) + return "processed" + + handler = OrderHandler() + test_order = OrderCreated(order_id="888", amount=88.88, customer="Diana") + + result = handler.handle(test_order) + + assert result == "processed" + assert "888" in handler.processed + assert hasattr(handler.handle, "_is_message_handler") + + +# Tests for register_message_handlers + + +def test_register_message_handlers_discovers_standalone_function(): + """Test that standalone decorated functions are discovered.""" + mock_client = MagicMock() + mock_client.subscribe_with_handler.return_value = MagicMock() + + @message_router(pubsub="messagepubsub", topic="orders") + def handle_order(message: OrderCreated): + return "success" + + closers = register_message_handlers([handle_order], mock_client, loop=None) + + # Should create one subscription + assert mock_client.subscribe_with_handler.call_count == 1 + 
+
+
+def test_register_message_handlers_discovers_standalone_function():
+    """Test that standalone decorated functions are discovered."""
+    mock_client = MagicMock()
+    mock_client.subscribe_with_handler.return_value = MagicMock()
+
+    @message_router(pubsub="messagepubsub", topic="orders")
+    def handle_order(message: OrderCreated):
+        return "success"
+
+    closers = register_message_handlers([handle_order], mock_client, loop=None)
+
+    # Should create one subscription
+    assert mock_client.subscribe_with_handler.call_count == 1
+    assert len(closers) == 1
+
+    # Verify subscription parameters
+    call_args = mock_client.subscribe_with_handler.call_args
+    assert call_args.kwargs["pubsub_name"] == "messagepubsub"
+    assert call_args.kwargs["topic"] == "orders"
+    assert call_args.kwargs["dead_letter_topic"] == "orders_DEAD"
+
+
+def test_register_message_handlers_discovers_class_methods():
+    """Test that decorated methods in class instances are discovered."""
+    mock_client = MagicMock()
+    mock_client.subscribe_with_handler.return_value = MagicMock()
+
+    class OrderHandler:
+        @message_router(pubsub="messagepubsub", topic="orders.created")
+        def handle_created(self, message: OrderCreated):
+            return "created"
+
+        @message_router(pubsub="messagepubsub", topic="orders.cancelled")
+        def handle_cancelled(self, message: OrderCancelled):
+            return "cancelled"
+
+    handler = OrderHandler()
+    closers = register_message_handlers([handler], mock_client, loop=None)
+
+    # Should create two subscriptions
+    assert mock_client.subscribe_with_handler.call_count == 2
+    assert len(closers) == 2
+
+    # Verify both topics were registered
+    topics = [
+        call.kwargs["topic"]
+        for call in mock_client.subscribe_with_handler.call_args_list
+    ]
+    assert "orders.created" in topics
+    assert "orders.cancelled" in topics
+
+
+def test_register_message_handlers_ignores_undecorated_methods():
+    """Test that methods without @message_router are ignored."""
+    mock_client = MagicMock()
+    mock_client.subscribe_with_handler.return_value = MagicMock()
+
+    class MixedHandler:
+        @message_router(pubsub="messagepubsub", topic="orders")
+        def decorated_handler(self, message: OrderCreated):
+            return "success"
+
+        def regular_method(self, message: OrderCreated):
+            """Not decorated, should be ignored."""
+            return "ignored"
+
+    handler = MixedHandler()
+    closers = register_message_handlers([handler], mock_client, loop=None)
+
+    # Should only create one subscription (for decorated method)
+    assert mock_client.subscribe_with_handler.call_count == 1
+    assert len(closers) == 1
+
+
+def test_register_message_handlers_handles_multiple_targets():
+    """Test registering multiple targets (functions and instances)."""
+    mock_client = MagicMock()
+    mock_client.subscribe_with_handler.return_value = MagicMock()
+
+    @message_router(pubsub="messagepubsub", topic="orders")
+    def standalone_handler(message: OrderCreated):
+        pass
+
+    class OrderHandler:
+        @message_router(pubsub="messagepubsub", topic="shipments")
+        def handle_shipment(self, message: ShipmentCreated):
+            pass
+
+    handler_instance = OrderHandler()
+    closers = register_message_handlers(
+        [standalone_handler, handler_instance], mock_client, loop=None
+    )
+
+    # Should create two subscriptions
+    assert mock_client.subscribe_with_handler.call_count == 2
+    assert len(closers) == 2
+
+
+def test_register_message_handlers_returns_closers():
+    """Test that closer functions are returned for each subscription."""
+    mock_client = MagicMock()
+    mock_client.subscribe_with_handler.return_value = MagicMock()
+
+    @message_router(pubsub="messagepubsub", topic="orders.created")
+    def handle_created(message: OrderCreated):
+        pass
+
+    @message_router(pubsub="messagepubsub", topic="orders.cancelled")
+    def handle_cancelled(message: OrderCancelled):
+        pass
+
+    closers = register_message_handlers(
+        [handle_created, handle_cancelled], mock_client, loop=None
+    )
+
+    # Should return two closers
+    assert len(closers) == 2
+    assert all(callable(closer) for closer in closers)
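+
+
+# Sketch: per the registration docstring, invoking a closer unsubscribes the
+# corresponding handler. Against a mocked client we can only assert that the
+# closers can be invoked safely; the close mechanism itself is an
+# implementation detail.
+def test_register_message_handlers_closers_can_be_invoked():
+    """Test that returned closers can be invoked without raising."""
+    mock_client = MagicMock()
+    mock_client.subscribe_with_handler.return_value = MagicMock()
+
+    @message_router(pubsub="messagepubsub", topic="orders")
+    def handle_order(message: OrderCreated):
+        pass
+
+    closers = register_message_handlers([handle_order], mock_client, loop=None)
+
+    for closer in closers:
+        closer()  # Should not raise against the mocked subscription handle.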