Skip to content
23 changes: 21 additions & 2 deletions dapr_agents/workflow/decorators/messaging.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import logging
import warnings
from copy import deepcopy
from typing import Any, Callable, Optional, get_type_hints

from dapr_agents.workflow.utils.core import is_valid_routable_model
from dapr_agents.workflow.utils.messaging import extract_message_models

logger = logging.getLogger(__name__)

_MESSAGE_ROUTER_DEPRECATION_MESSAGE = (
"@message_router (legacy version from dapr_agents.workflow.decorators.messaging) "
"is deprecated and will be removed in a future release. "
"Please migrate to the updated decorator in "
"`dapr_agents.workflow.decorators.routers`, which supports "
"Union types, forward references, and explicit Dapr workflow integration."
)


def message_router(
func: Optional[Callable[..., Any]] = None,
Expand All @@ -16,7 +26,8 @@ def message_router(
broadcast: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Decorator for registering message handlers by inspecting type hints on the 'message' argument.
[DEPRECATED] Legacy decorator for registering message handlers by inspecting type hints
on the 'message' argument.

This decorator:
- Extracts the expected message model type from function annotations.
Expand All @@ -36,6 +47,12 @@ def message_router(
"""

def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
warnings.warn(
_MESSAGE_ROUTER_DEPRECATION_MESSAGE,
DeprecationWarning,
stacklevel=2,
)

is_workflow = hasattr(f, "_is_workflow")
workflow_name = getattr(f, "_workflow_name", None)

Expand All @@ -56,7 +73,9 @@ def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
)

logger.debug(
f"@message_router: '{f.__name__}' => models {[m.__name__ for m in message_models]}"
"@message_router (legacy): '%s' => models %s",
f.__name__,
[m.__name__ for m in message_models],
)

# Attach metadata for later registration
Expand Down
113 changes: 113 additions & 0 deletions dapr_agents/workflow/decorators/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

import inspect
import logging
from copy import deepcopy
from typing import (
Any,
Callable,
Optional,
get_type_hints,
)

from dapr_agents.workflow.utils.core import is_supported_model
from dapr_agents.workflow.utils.routers import extract_message_models

logger = logging.getLogger(__name__)


def message_router(
func: Optional[Callable[..., Any]] = None,
*,
pubsub: Optional[str] = None,
topic: Optional[str] = None,
dead_letter_topic: Optional[str] = None,
broadcast: bool = False,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Decorate a message handler with routing metadata.
The handler must accept a parameter named `message`. Its type hint defines the
expected payload model(s), e.g.:
@message_router(pubsub="pubsub", topic="orders")
def on_order(message: OrderCreated): ...
@message_router(pubsub="pubsub", topic="events")
def on_event(message: Union[Foo, Bar]): ...
Args:
func: (optional) bare-decorator form support.
pubsub: Name of the Dapr pub/sub component (required when used with args).
topic: Topic name to subscribe to (required when used with args).
dead_letter_topic: Optional dead-letter topic (defaults to f"{topic}_DEAD").
broadcast: Optional flag you can use downstream for fan-out semantics.
Returns:
The original function tagged with `_message_router_data`.
"""

def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
# Validate required kwargs only when decorator is used with args
if pubsub is None or topic is None:
raise ValueError(
"`pubsub` and `topic` are required when using @message_router with arguments."
)

sig = inspect.signature(f)
if "message" not in sig.parameters:
raise ValueError(f"'{f.__name__}' must have a 'message' parameter.")

# Resolve forward refs under PEP 563 / future annotations
try:
hints = get_type_hints(f, globalns=f.__globals__)
except Exception:
logger.debug(
"Failed to fully resolve type hints for %s", f.__name__, exc_info=True
)
hints = getattr(f, "__annotations__", {}) or {}

raw_hint = hints.get("message")
if raw_hint is None:
raise TypeError(
f"'{f.__name__}' must type-hint the 'message' parameter "
"(e.g., 'message: MyModel' or 'message: Union[A, B]')"
)

models = extract_message_models(raw_hint)
if not models:
raise TypeError(
f"Unsupported or unresolved message type for '{f.__name__}': {raw_hint!r}"
)

# Optional early validation of supported schema kinds
for m in models:
if not is_supported_model(m):
raise TypeError(f"Unsupported model type in '{f.__name__}': {m!r}")

data = {
"pubsub": pubsub,
"topic": topic,
"dead_letter_topic": dead_letter_topic
or (f"{topic}_DEAD" if topic else None),
"is_broadcast": broadcast,
"message_schemas": models, # list[type]
"message_types": [m.__name__ for m in models], # list[str]
}

# Attach metadata; deepcopy for defensive isolation
setattr(f, "_is_message_handler", True)
setattr(f, "_message_router_data", deepcopy(data))

logger.debug(
"@message_router: '%s' => models %s (topic=%s, pubsub=%s, broadcast=%s)",
f.__name__,
[m.__name__ for m in models],
topic,
pubsub,
broadcast,
)
return f

# Support both @message_router(...) and bare @message_router usage
return decorator if func is None else decorator(func)
156 changes: 156 additions & 0 deletions dapr_agents/workflow/utils/registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from __future__ import annotations

import asyncio
import inspect
import logging
from typing import Any, Callable, Iterable, List, Optional, Type

from dapr.clients import DaprClient
from dapr.clients.grpc._response import TopicEventResponse
from dapr.common.pubsub.subscription import SubscriptionMessage

from dapr_agents.workflow.utils.messaging import (
extract_cloudevent_data,
validate_message_model,
)

logger = logging.getLogger(__name__)


def register_message_handlers(
targets: Iterable[Any],
dapr_client: DaprClient,
*,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> List[Callable[[], None]]:
"""Discover and subscribe handlers decorated with `@message_router`.

Scans each target:
- If the target itself is a decorated function (has `_message_router_data`), it is registered.
- If the target is an object, all its attributes are scanned for decorated callables.

Subscriptions use Dapr's streaming API (`subscribe_with_handler`) which invokes your handler
on a background thread. This function returns a list of "closer" callables. Invoking a closer
will unsubscribe the corresponding handler.

Args:
targets: Functions and/or instances to inspect for `_message_router_data`.
dapr_client: Active Dapr client used to create subscriptions.
loop: Event loop to await async handlers. If omitted, uses the running loop
or falls back to `asyncio.get_event_loop()`.

Returns:
A list of callables. Each callable, when invoked, closes the associated subscription.
"""
# Resolve loop strategy once up front.
if loop is None:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.get_event_loop()

closers: List[Callable[[], None]] = []

def _iter_handlers(obj: Any):
"""Yield (owner, fn) pairs for decorated handlers on `obj`.

If `obj` is itself a decorated function, yield (None, obj).
If `obj` is an instance, scan its attributes for decorated callables.
"""
meta = getattr(obj, "_message_router_data", None)
if callable(obj) and meta:
yield None, obj
return

for name in dir(obj):
fn = getattr(obj, name)
if callable(fn) and getattr(fn, "_message_router_data", None):
yield obj, fn

for target in targets:
for owner, handler in _iter_handlers(target):
meta = getattr(handler, "_message_router_data")
schemas: List[Type[Any]] = meta.get("message_schemas") or []

# Bind method to instance if needed (descriptor protocol).
bound = (
handler if owner is None else handler.__get__(owner, owner.__class__)
)

async def _invoke(
bound_handler: Callable[..., Any],
parsed: Any,
) -> TopicEventResponse:
"""Invoke the user handler (sync or async) and normalize the result."""
result = bound_handler(parsed)
if inspect.iscoroutine(result):
result = await result
if isinstance(result, TopicEventResponse):
return result
# Treat any truthy/None return as success unless user explicitly returns a response.
return TopicEventResponse("success")

def _make_handler(
bound_handler: Callable[..., Any],
) -> Callable[[SubscriptionMessage], TopicEventResponse]:
"""Create a Dapr-compatible handler for a single decorated function."""

def handler_fn(message: SubscriptionMessage) -> TopicEventResponse:
try:
# 1) Extract payload + CloudEvent metadata (bytes/str/dict are also supported by the extractor)
event_data, metadata = extract_cloudevent_data(message)

# 2) Validate against the first matching schema (or dict as fallback)
parsed = None
for model in schemas or [dict]:
try:
parsed = validate_message_model(model, event_data)
break
except Exception:
# Try the next schema; log at debug for signal without noise.
logger.debug(
"Schema %r did not match payload; trying next.",
model,
exc_info=True,
)
continue

if parsed is None:
# Permanent schema mismatch → drop (DLQ if configured by Dapr)
logger.warning(
"No matching schema for message on topic %r; dropping. Raw payload: %r",
meta["topic"],
event_data,
)
return TopicEventResponse("drop")

# 3) Attach CE metadata for downstream consumers
if isinstance(parsed, dict):
parsed["_message_metadata"] = metadata
else:
setattr(parsed, "_message_metadata", metadata)

# 4) Bridge worker thread → event loop
if loop and loop.is_running():
fut = asyncio.run_coroutine_threadsafe(
_invoke(bound_handler, parsed), loop
)
return fut.result()
return asyncio.run(_invoke(bound_handler, parsed))

except Exception:
# Transient failure (I/O, handler crash, etc.) → retry
logger.exception("Message handler error; requesting retry.")
return TopicEventResponse("retry")

return handler_fn

close_fn = dapr_client.subscribe_with_handler(
pubsub_name=meta["pubsub"],
topic=meta["topic"],
handler_fn=_make_handler(bound),
dead_letter_topic=meta.get("dead_letter_topic"),
)
closers.append(close_fn)

return closers
Loading