Skip to content

Commit d6fd666

Browse files
committed
Added new message router utils
Signed-off-by: Roberto Rodriguez <[email protected]>
1 parent 1b162ca commit d6fd666

File tree

1 file changed

+227
-0
lines changed

1 file changed

+227
-0
lines changed
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import logging
5+
import types
6+
from dataclasses import is_dataclass
7+
from types import NoneType
8+
from typing import Any, Optional, Tuple, Type, Union, get_args, get_origin
9+
10+
from dapr.common.pubsub.subscription import SubscriptionMessage
11+
12+
from dapr_agents.types.message import EventMessageMetadata
13+
from dapr_agents.workflow.utils.core import is_pydantic_model, is_supported_model
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
def extract_message_models(type_hint: Any) -> list[type]:
19+
"""Normalize a message type hint into a concrete list of classes.
20+
21+
Supports:
22+
- Single class: `MyMessage` → `[MyMessage]`
23+
- Union: `Union[Foo, Bar]` or `Foo | Bar` → `[Foo, Bar]`
24+
- Optional: `Optional[Foo]` (i.e., `Union[Foo, None]`) → `[Foo]`
25+
26+
Notes:
27+
- Forward refs should be resolved by the caller (e.g., via `typing.get_type_hints`).
28+
- Non-class entries (e.g., `None`, `typing.Any`) are filtered out.
29+
- Returns an empty list when the hint isn't a usable class or union of classes.
30+
"""
31+
if type_hint is None:
32+
return []
33+
34+
origin = get_origin(type_hint)
35+
if origin in (Union, types.UnionType): # handle both `Union[...]` and `A | B`
36+
return [
37+
t for t in get_args(type_hint)
38+
if t is not NoneType and isinstance(t, type)
39+
]
40+
41+
return [type_hint] if isinstance(type_hint, type) else []
42+
43+
44+
def _maybe_json_loads(payload: Any, content_type: Optional[str]) -> Any:
    """
    Best-effort JSON parsing based on content type and payload shape.

    - If payload is `dict`/`list` → return as-is.
    - If bytes/str and content-type hints JSON (or the text, ignoring leading
      whitespace, starts with `{` or `[`) → parse to Python.
    - If bytes that do not look like JSON → return the UTF-8 decoded text.
    - Otherwise → return the original payload.

    This helper is intentionally forgiving; callers should validate downstream.

    Args:
        payload: The raw event payload.
        content_type: The declared content type, if any. Matching is a
            case-insensitive substring check for "json".

    Returns:
        The parsed JSON value, the decoded text, or the original payload when
        decoding/parsing fails or the payload type is not handled.
    """
    try:
        if isinstance(payload, (dict, list)):
            return payload

        declared_json = "json" in (content_type or "").lower()

        if isinstance(payload, bytes):
            text = payload.decode("utf-8", errors="strict")
        elif isinstance(payload, str):
            text = payload
        else:
            return payload

        # json.loads tolerates leading whitespace, so the sniff must too;
        # checking only text[0] misses payloads such as '\n{"a": 1}'.
        if declared_json or text.lstrip().startswith(("{", "[")):
            return json.loads(text)
        return text
    except Exception:
        # Deliberately broad: this is best-effort and must never raise.
        logger.debug("JSON parsing failed; returning raw payload", exc_info=True)
        return payload
76+
77+
78+
def extract_cloudevent_data(
    message: Union[SubscriptionMessage, dict, bytes, str],
) -> Tuple[dict, dict]:
    """
    Split an incoming message into its payload and CloudEvent metadata.

    The payload is passed through `_maybe_json_loads`, so JSON text/bytes are
    parsed when appropriate.

    Accepts:
        - `SubscriptionMessage` (Dapr SDK): fields are read via its accessors.
        - `dict` (raw CloudEvent envelope): fields are read with `.get`.
        - `bytes`/`str` (data-only): metadata is synthesized with every field
          unset except `datacontenttype` ("application/json") and empty headers.

    Returns:
        (event_data, metadata). `metadata` is the `model_dump()` of an
        `EventMessageMetadata`. `event_data` may be a non-dict value
        (e.g., a list) if the payload is an array; callers expecting dicts
        should handle it.

    Raises:
        ValueError: For unsupported `message` types.
    """
    # Envelope attributes that are copied verbatim / defaulted to None.
    passthrough_attrs = (
        "id",
        "pubsubname",
        "source",
        "specversion",
        "time",
        "topic",
        "traceid",
        "traceparent",
        "type",
        "tracestate",
    )

    if isinstance(message, SubscriptionMessage):
        declared_type = message.data_content_type()
        event_data = _maybe_json_loads(message.data(), declared_type)
        fields = {
            "id": message.id(),
            "datacontenttype": declared_type,
            "pubsubname": message.pubsub_name(),
            "source": message.source(),
            "specversion": message.spec_version(),
            "time": None,  # not always populated by SDK
            "topic": message.topic(),
            "traceid": None,
            "traceparent": None,
            "type": message.type(),
            "tracestate": None,
            "headers": message.extensions(),
        }

    elif isinstance(message, dict):
        declared_type = message.get("datacontenttype")
        event_data = _maybe_json_loads(message.get("data", {}), declared_type)
        fields = {attr: message.get(attr) for attr in passthrough_attrs}
        fields["datacontenttype"] = declared_type
        fields["headers"] = message.get("extensions", {})

    elif isinstance(message, (bytes, str)):
        # No CloudEvent envelope; treat payload as data-only and synthesize minimal metadata.
        declared_type = "application/json"
        event_data = _maybe_json_loads(message, declared_type)
        fields = dict.fromkeys(passthrough_attrs)
        fields["datacontenttype"] = declared_type
        fields["headers"] = {}

    else:
        raise ValueError(f"Unexpected message type: {type(message)!r}")

    metadata = EventMessageMetadata(**fields).model_dump()

    if not isinstance(event_data, dict):
        logger.debug("Event data is not a dict (type=%s); value=%r", type(event_data), event_data)

    return event_data, metadata
160+
161+
162+
def validate_message_model(model: Type[Any], event_data: dict) -> Any:
    """
    Validate and coerce `event_data` into `model`.

    Supports:
        - dict: returns `event_data` unchanged
        - dataclass: constructs the dataclass with `event_data` as keyword arguments
        - Pydantic v2 model: uses `model_validate`

    Args:
        model: The target schema class.
        event_data: The decoded event payload.

    Returns:
        `event_data` itself for `dict`, otherwise an instance of `model`.

    Raises:
        TypeError: If the model is not a supported kind.
        ValueError: If validation/construction fails. The original error is
            attached as `__cause__`.
    """
    if not is_supported_model(model):
        raise TypeError(f"Unsupported model type: {model!r}")

    logger.info("Validating payload with model '%s'...", model.__name__)

    if model is dict:
        return event_data

    as_dataclass = is_dataclass(model)
    # Kept outside the try block so this TypeError is not re-wrapped as a
    # ValueError, matching the documented contract above.
    if not as_dataclass and not is_pydantic_model(model):
        raise TypeError(f"Unsupported model type: {model!r}")

    try:
        if as_dataclass:
            return model(**event_data)
        return model.model_validate(event_data)
    except Exception as e:
        logger.error("Message validation failed for model '%s': %s", model.__name__, e)
        raise ValueError(f"Message validation failed: {e}") from e
191+
192+
193+
def parse_cloudevent(
    message: Union[SubscriptionMessage, dict, bytes, str],
    model: Optional[Type[Any]] = None,
) -> Tuple[Any, dict]:
    """
    Parse a CloudEvent-like input and validate its payload against ``model``.

    Args:
        message (Union[SubscriptionMessage, dict, bytes, str]): Incoming message; can be a Dapr ``SubscriptionMessage``, a raw
            CloudEvent ``dict``, or bare ``bytes``/``str`` payloads.
        model (Optional[Type[Any]]): Schema for payload validation. Defaults to
            ``None`` in the signature, but a value is required: ``None`` is rejected.

    Returns:
        Tuple[Any, dict]: A tuple containing the validated message and its metadata.

    Raises:
        ValueError: If no model is provided, the message type is unsupported,
            or validation fails. Every failure is re-raised as ``ValueError``
            with the original exception attached as ``__cause__``.
    """
    try:
        event_data, metadata = extract_cloudevent_data(message)

        if model is None:
            raise ValueError("Message validation failed: No model provided.")

        validated_message = validate_message_model(model, event_data)
    except Exception as e:
        logger.error("Failed to parse CloudEvent: %s", e, exc_info=True)
        raise ValueError(f"Invalid CloudEvent: {str(e)}") from e

    logger.info("Message successfully parsed and validated")
    # Lazy %-style arguments: the payload is only rendered when DEBUG is enabled.
    logger.debug("Data: %s", validated_message)
    logger.debug("metadata: %s", metadata)

    return validated_message, metadata

0 commit comments

Comments
 (0)