Skip to content

Commit c53ea26

Browse files
Support sub channel identification from Activities (#150)
* Adding skeleton for sub-channel handling * ChannelId test setup * Defining serialized and validators for Activity and ChannelId * Passing ChannelId tests * Fixing imports * Reorganizing Activity tests * Test adjustment * Adjusting setter/getter for _channel_id in Activity * Fixing test cases and finalizing Activity serializer * Tweaks to docstrings * Addressing review comments * Addressing edge case * Completed fix for serializing a None * Refactoring to make ChannelId a subclass of str * Updated implementation details * Removing Self import from typing * Addressing PR comments * Addressing PR review and making entities subclass from Entity * Raising exceptions when ProductInfo and channel_id.sub_channel conflict * Adding copyright comment * Reverting strenum usage * Removing unnecessary str conversion and unnecessary comments
1 parent 7cde971 commit c53ea26

File tree

22 files changed

+855
-34
lines changed

22 files changed

+855
-34
lines changed

libraries/microsoft-agents-activity/microsoft_agents/activity/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
14
from .agents_model import AgentsModel
25
from .action_types import ActionTypes
36
from .activity import Activity
@@ -17,6 +20,8 @@
1720
from .card_image import CardImage
1821
from .channels import Channels
1922
from .channel_account import ChannelAccount
23+
from ._channel_id_field_mixin import _ChannelIdFieldMixin
24+
from .channel_id import ChannelId
2025
from .conversation_account import ConversationAccount
2126
from .conversation_members import ConversationMembers
2227
from .conversation_parameters import ConversationParameters
@@ -26,6 +31,7 @@
2631
from .expected_replies import ExpectedReplies
2732
from .entity import (
2833
Entity,
34+
EntityTypes,
2935
AIEntity,
3036
ClientCitation,
3137
ClientCitationAppearance,
@@ -36,6 +42,7 @@
3642
SensitivityPattern,
3743
GeoCoordinates,
3844
Place,
45+
ProductInfo,
3946
Thing,
4047
)
4148
from .error import Error
@@ -115,6 +122,8 @@
115122
"CardImage",
116123
"Channels",
117124
"ChannelAccount",
125+
"ChannelId",
126+
"_ChannelIdFieldMixin",
118127
"ConversationAccount",
119128
"ConversationMembers",
120129
"ConversationParameters",
@@ -145,6 +154,7 @@
145154
"OAuthCard",
146155
"PagedMembersResult",
147156
"Place",
157+
"ProductInfo",
148158
"ReceiptCard",
149159
"ReceiptItem",
150160
"ResourceResponse",
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
from __future__ import annotations
5+
6+
import logging
7+
from typing import Optional, Any
8+
9+
from pydantic import (
10+
ModelWrapValidatorHandler,
11+
SerializerFunctionWrapHandler,
12+
computed_field,
13+
model_validator,
14+
model_serializer,
15+
)
16+
17+
from .channel_id import ChannelId
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
# can be generalized in the future, if needed
23+
class _ChannelIdFieldMixin:
24+
"""A mixin to add a computed field channel_id of type ChannelId to a Pydantic model."""
25+
26+
_channel_id: Optional[ChannelId] = None
27+
28+
# required to define the setter below
29+
@computed_field(return_type=Optional[ChannelId], alias="channelId")
30+
@property
31+
def channel_id(self) -> Optional[ChannelId]:
32+
"""Gets the _channel_id field"""
33+
return self._channel_id
34+
35+
# necessary for backward compatibility
36+
# previously, channel_id was directly assigned with strings
37+
@channel_id.setter
38+
def channel_id(self, value: Any):
39+
"""Sets the channel_id after validating it as a ChannelId model."""
40+
if isinstance(value, ChannelId):
41+
self._channel_id = value
42+
elif isinstance(value, str):
43+
self._channel_id = ChannelId(value)
44+
else:
45+
raise ValueError(
46+
f"Invalid type for channel_id: {type(value)}. "
47+
"Expected ChannelId or str."
48+
)
49+
50+
def _set_validated_channel_id(self, data: Any) -> None:
51+
"""Sets the channel_id after validating it as a ChannelId model."""
52+
if "channelId" in data:
53+
self.channel_id = data["channelId"]
54+
elif "channel_id" in data:
55+
self.channel_id = data["channel_id"]
56+
57+
@model_validator(mode="wrap")
58+
@classmethod
59+
def _validate_channel_id(
60+
cls, data: Any, handler: ModelWrapValidatorHandler
61+
) -> _ChannelIdFieldMixin:
62+
"""Validate the _channel_id field after model initialization.
63+
64+
:return: The model instance itself.
65+
"""
66+
try:
67+
model = handler(data)
68+
model._set_validated_channel_id(data)
69+
return model
70+
except Exception:
71+
logging.error("Model %s failed to validate with data %s", cls, data)
72+
raise
73+
74+
def _remove_serialized_unset_channel_id(
75+
self, serialized: dict[str, object]
76+
) -> None:
77+
"""Remove the _channel_id field if it is not set."""
78+
if not self._channel_id:
79+
if "channelId" in serialized:
80+
del serialized["channelId"]
81+
elif "channel_id" in serialized:
82+
del serialized["channel_id"]
83+
84+
@model_serializer(mode="wrap")
85+
def _serialize_channel_id(
86+
self, handler: SerializerFunctionWrapHandler
87+
) -> dict[str, object]:
88+
"""Serialize the model using Pydantic's standard serialization.
89+
90+
:param handler: The serialization handler provided by Pydantic.
91+
:return: A dictionary representing the serialized model.
92+
"""
93+
serialized = handler(self)
94+
if self: # serialization can be called with None
95+
self._remove_serialized_unset_channel_id(serialized)
96+
return serialized

libraries/microsoft-agents-activity/microsoft_agents/activity/activity.py

Lines changed: 120 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
33

4+
from __future__ import annotations
5+
6+
import logging
47
from copy import copy
58
from datetime import datetime, timezone
6-
from typing import Optional
7-
from pydantic import Field, SerializeAsAny
9+
from typing import Optional, Any
10+
11+
from pydantic import (
12+
Field,
13+
SerializeAsAny,
14+
model_serializer,
15+
model_validator,
16+
SerializerFunctionWrapHandler,
17+
ModelWrapValidatorHandler,
18+
computed_field,
19+
ValidationError,
20+
)
21+
822
from .activity_types import ActivityTypes
923
from .channel_account import ChannelAccount
1024
from .conversation_account import ConversationAccount
@@ -14,22 +28,28 @@
1428
from .attachment import Attachment
1529
from .entity import (
1630
Entity,
31+
EntityTypes,
1732
Mention,
1833
AIEntity,
1934
ClientCitation,
35+
ProductInfo,
2036
SensitivityUsageInfo,
2137
)
2238
from .conversation_reference import ConversationReference
2339
from .text_highlight import TextHighlight
2440
from .semantic_action import SemanticAction
2541
from .agents_model import AgentsModel
2642
from .role_types import RoleTypes
43+
from ._channel_id_field_mixin import _ChannelIdFieldMixin
44+
from .channel_id import ChannelId
2745
from ._model_utils import pick_model, SkipNone
2846
from ._type_aliases import NonEmptyString
2947

48+
logger = logging.getLogger(__name__)
49+
3050

3151
# TODO: A2A Agent 2 is responding with None as id, had to mark it as optional (investigate)
32-
class Activity(AgentsModel):
52+
class Activity(AgentsModel, _ChannelIdFieldMixin):
3353
"""An Activity is the basic communication type for the protocol.
3454
3555
:param type: Contains the activity type. Possible values include:
@@ -50,8 +70,8 @@ class Activity(AgentsModel):
5070
:type local_timezone: str
5171
:param service_url: Contains the URL that specifies the channel's service endpoint. Set by the channel.
5272
:type service_url: str
53-
:param channel_id: Contains an ID that uniquely identifies the channel. Set by the channel.
54-
:type channel_id: str
73+
:param channel_id: Contains an ID that uniquely identifies the channel (and possibly the sub-channel). Set by the channel.
74+
:type channel_id: ~microsoft_agents.activity.ChannelId
5575
:param from_property: Identifies the sender of the message.
5676
:type from_property: ~microsoft_agents.activity.ChannelAccount
5777
:param conversation: Identifies the conversation to which the activity belongs.
@@ -136,7 +156,6 @@ class Activity(AgentsModel):
136156
local_timestamp: datetime = None
137157
local_timezone: NonEmptyString = None
138158
service_url: NonEmptyString = None
139-
channel_id: NonEmptyString = None
140159
from_property: ChannelAccount = Field(None, alias="from")
141160
conversation: ConversationAccount = None
142161
recipient: ChannelAccount = None
@@ -173,6 +192,92 @@ class Activity(AgentsModel):
173192
semantic_action: SemanticAction = None
174193
caller_id: NonEmptyString = None
175194

195+
@model_validator(mode="wrap")
196+
@classmethod
197+
def _validate_channel_id(
198+
cls, data: Any, handler: ModelWrapValidatorHandler[Activity]
199+
) -> Activity:
200+
"""Validate the Activity, ensuring consistency between channel_id.sub_channel and productInfo entity.
201+
202+
:param data: The input data to validate.
203+
:param handler: The validation handler provided by Pydantic.
204+
:return: The validated Activity instance.
205+
"""
206+
try:
207+
# run Pydantic's standard validation first
208+
activity = handler(data)
209+
210+
# needed to assign to a computed field
211+
# needed because we override the mixin validator
212+
activity._set_validated_channel_id(data)
213+
214+
# sync sub_channel with productInfo entity
215+
product_info = activity.get_product_info_entity()
216+
if product_info and activity.channel_id:
217+
if (
218+
activity.channel_id.sub_channel
219+
and activity.channel_id.sub_channel != product_info.id
220+
):
221+
raise Exception(
222+
"Conflict between channel_id.sub_channel and productInfo entity"
223+
)
224+
activity.channel_id = ChannelId(
225+
channel=activity.channel_id.channel,
226+
sub_channel=product_info.id,
227+
)
228+
229+
return activity
230+
except ValidationError as exc:
231+
logger.error("Validation error for Activity: %s", exc, exc_info=True)
232+
raise
233+
234+
@model_serializer(mode="wrap")
235+
def _serialize_sub_channel_data(
236+
self, handler: SerializerFunctionWrapHandler
237+
) -> dict[str, object]:
238+
"""Serialize the Activity, ensuring consistency between channel_id.sub_channel and productInfo entity.
239+
240+
:param handler: The serialization handler provided by Pydantic.
241+
:return: A dictionary representing the serialized Activity.
242+
"""
243+
244+
# run Pydantic's standard serialization first
245+
serialized = handler(self)
246+
if not self: # serialization can be called with None
247+
return serialized
248+
249+
# find the ProductInfo entity
250+
product_info = None
251+
for i, entity in enumerate(serialized.get("entities") or []):
252+
if entity.get("type", "") == EntityTypes.PRODUCT_INFO:
253+
product_info = entity
254+
break
255+
256+
# maintain consistency between ProductInfo entity and sub channel
257+
if self.channel_id and self.channel_id.sub_channel:
258+
if product_info and product_info.get("id") != self.channel_id.sub_channel:
259+
raise Exception(
260+
"Conflict between channel_id.sub_channel and productInfo entity"
261+
)
262+
elif not product_info:
263+
if not serialized.get("entities"):
264+
serialized["entities"] = []
265+
serialized["entities"].append(
266+
{
267+
"type": EntityTypes.PRODUCT_INFO,
268+
"id": self.channel_id.sub_channel,
269+
}
270+
)
271+
elif product_info: # remove productInfo entity if sub_channel is not set
272+
del serialized["entities"][i]
273+
if not serialized["entities"]: # after removal above, list may be empty
274+
del serialized["entities"]
275+
276+
# necessary due to computed_field serialization
277+
self._remove_serialized_unset_channel_id(serialized)
278+
279+
return serialized
280+
176281
def apply_conversation_reference(
177282
self, reference: ConversationReference, is_incoming: bool = False
178283
):
@@ -531,6 +636,14 @@ def get_conversation_reference(self) -> ConversationReference:
531636
service_url=self.service_url,
532637
)
533638

639+
def get_product_info_entity(self) -> Optional[ProductInfo]:
640+
if not self.entities:
641+
return None
642+
target = EntityTypes.PRODUCT_INFO.lower()
643+
# validated entities can be Entity, and that prevents us from
644+
# making assumptions about the casing of the 'type' attribute
645+
return next(filter(lambda e: e.type.lower() == target, self.entities), None)
646+
534647
def get_mentions(self) -> list[Mention]:
535648
"""
536649
Resolves the mentions from the entities of this activity.
@@ -543,7 +656,7 @@ def get_mentions(self) -> list[Mention]:
543656
"""
544657
if not self.entities:
545658
return []
546-
return [x for x in self.entities if x.type.lower() == "mention"]
659+
return [x for x in self.entities if x.type.lower() == EntityTypes.MENTION]
547660

548661
def get_reply_conversation_reference(
549662
self, reply: ResourceResponse

0 commit comments

Comments
 (0)