Skip to content

Commit 160004e

Browse files
committed
Added util to subscribe handlers decorated with new message_router
Signed-off-by: Roberto Rodriguez <[email protected]>
1 parent d6fd666 commit 160004e

File tree

1 file changed

+147
-0
lines changed

1 file changed

+147
-0
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import inspect
5+
import logging
6+
from typing import Any, Callable, Iterable, List, Optional, Type
7+
8+
from dapr.clients import DaprClient
9+
from dapr.clients.grpc._response import TopicEventResponse
10+
from dapr.common.pubsub.subscription import SubscriptionMessage
11+
12+
from dapr_agents.workflow.utils.messaging import (
13+
extract_cloudevent_data,
14+
validate_message_model,
15+
)
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def register_message_handlers(
21+
targets: Iterable[Any],
22+
dapr_client: DaprClient,
23+
*,
24+
loop: Optional[asyncio.AbstractEventLoop] = None,
25+
) -> List[Callable[[], None]]:
26+
"""Discover and subscribe handlers decorated with `@message_router`.
27+
28+
Scans each target:
29+
- If the target itself is a decorated function (has `_message_router_data`), it is registered.
30+
- If the target is an object, all its attributes are scanned for decorated callables.
31+
32+
Subscriptions use Dapr's streaming API (`subscribe_with_handler`) which invokes your handler
33+
on a background thread. This function returns a list of "closer" callables. Invoking a closer
34+
will unsubscribe the corresponding handler.
35+
36+
Args:
37+
targets: Functions and/or instances to inspect for `_message_router_data`.
38+
dapr_client: Active Dapr client used to create subscriptions.
39+
loop: Event loop to await async handlers. If omitted, uses the running loop
40+
or falls back to `asyncio.get_event_loop()`.
41+
42+
Returns:
43+
A list of callables. Each callable, when invoked, closes the associated subscription.
44+
"""
45+
# Resolve loop strategy once up front.
46+
if loop is None:
47+
try:
48+
loop = asyncio.get_running_loop()
49+
except RuntimeError:
50+
loop = asyncio.get_event_loop()
51+
52+
closers: List[Callable[[], None]] = []
53+
54+
def _iter_handlers(obj: Any):
55+
"""Yield (owner, fn) pairs for decorated handlers on `obj`.
56+
57+
If `obj` is itself a decorated function, yield (None, obj).
58+
If `obj` is an instance, scan its attributes for decorated callables.
59+
"""
60+
meta = getattr(obj, "_message_router_data", None)
61+
if callable(obj) and meta:
62+
yield None, obj
63+
return
64+
65+
for name in dir(obj):
66+
fn = getattr(obj, name)
67+
if callable(fn) and getattr(fn, "_message_router_data", None):
68+
yield obj, fn
69+
70+
for target in targets:
71+
for owner, handler in _iter_handlers(target):
72+
meta = getattr(handler, "_message_router_data")
73+
schemas: List[Type[Any]] = meta.get("message_schemas") or []
74+
75+
# Bind method to instance if needed (descriptor protocol).
76+
bound = handler if owner is None else handler.__get__(owner, owner.__class__)
77+
78+
async def _invoke(
79+
bound_handler: Callable[..., Any],
80+
parsed: Any,
81+
) -> TopicEventResponse:
82+
"""Invoke the user handler (sync or async) and normalize the result."""
83+
result = bound_handler(parsed)
84+
if inspect.iscoroutine(result):
85+
result = await result
86+
if isinstance(result, TopicEventResponse):
87+
return result
88+
# Treat any truthy/None return as success unless user explicitly returns a response.
89+
return TopicEventResponse("success")
90+
91+
def _make_handler(
92+
bound_handler: Callable[..., Any],
93+
) -> Callable[[SubscriptionMessage], TopicEventResponse]:
94+
"""Create a Dapr-compatible handler for a single decorated function."""
95+
def handler_fn(message: SubscriptionMessage) -> TopicEventResponse:
96+
try:
97+
# 1) Extract payload + CloudEvent metadata (bytes/str/dict are also supported by the extractor)
98+
event_data, metadata = extract_cloudevent_data(message)
99+
100+
# 2) Validate against the first matching schema (or dict as fallback)
101+
parsed = None
102+
for model in (schemas or [dict]):
103+
try:
104+
parsed = validate_message_model(model, event_data)
105+
break
106+
except Exception:
107+
# Try the next schema; log at debug for signal without noise.
108+
logger.debug("Schema %r did not match payload; trying next.", model, exc_info=True)
109+
continue
110+
111+
if parsed is None:
112+
# Permanent schema mismatch → drop (DLQ if configured by Dapr)
113+
logger.warning(
114+
"No matching schema for message on topic %r; dropping. Raw payload: %r",
115+
meta["topic"],
116+
event_data,
117+
)
118+
return TopicEventResponse("drop")
119+
120+
# 3) Attach CE metadata for downstream consumers
121+
if isinstance(parsed, dict):
122+
parsed["_message_metadata"] = metadata
123+
else:
124+
setattr(parsed, "_message_metadata", metadata)
125+
126+
# 4) Bridge worker thread → event loop
127+
if loop and loop.is_running():
128+
fut = asyncio.run_coroutine_threadsafe(_invoke(bound_handler, parsed), loop)
129+
return fut.result()
130+
return asyncio.run(_invoke(bound_handler, parsed))
131+
132+
except Exception:
133+
# Transient failure (I/O, handler crash, etc.) → retry
134+
logger.exception("Message handler error; requesting retry.")
135+
return TopicEventResponse("retry")
136+
137+
return handler_fn
138+
139+
close_fn = dapr_client.subscribe_with_handler(
140+
pubsub_name=meta["pubsub"],
141+
topic=meta["topic"],
142+
handler_fn=_make_handler(bound),
143+
dead_letter_topic=meta.get("dead_letter_topic"),
144+
)
145+
closers.append(close_fn)
146+
147+
return closers

0 commit comments

Comments
 (0)