11# Copyright (c) Microsoft Corporation. All rights reserved.
22# Licensed under the MIT License.
33
4+ from __future__ import annotations
5+
6+ import logging
47from copy import copy
58from datetime import datetime , timezone
6- from typing import Optional
7- from pydantic import Field , SerializeAsAny
9+ from typing import Optional , Any
10+
11+ from pydantic import (
12+ Field ,
13+ SerializeAsAny ,
14+ model_serializer ,
15+ model_validator ,
16+ SerializerFunctionWrapHandler ,
17+ ModelWrapValidatorHandler ,
18+ computed_field ,
19+ ValidationError ,
20+ )
21+
822from .activity_types import ActivityTypes
923from .channel_account import ChannelAccount
1024from .conversation_account import ConversationAccount
1428from .attachment import Attachment
1529from .entity import (
1630 Entity ,
31+ EntityTypes ,
1732 Mention ,
1833 AIEntity ,
1934 ClientCitation ,
35+ ProductInfo ,
2036 SensitivityUsageInfo ,
2137)
2238from .conversation_reference import ConversationReference
2339from .text_highlight import TextHighlight
2440from .semantic_action import SemanticAction
2541from .agents_model import AgentsModel
2642from .role_types import RoleTypes
43+ from ._channel_id_field_mixin import _ChannelIdFieldMixin
44+ from .channel_id import ChannelId
2745from ._model_utils import pick_model , SkipNone
2846from ._type_aliases import NonEmptyString
2947
48+ logger = logging .getLogger (__name__ )
49+
3050
3151# TODO: A2A Agent 2 is responding with None as id, had to mark it as optional (investigate)
32- class Activity (AgentsModel ):
52+ class Activity (AgentsModel , _ChannelIdFieldMixin ):
3353 """An Activity is the basic communication type for the protocol.
3454
3555 :param type: Contains the activity type. Possible values include:
@@ -50,8 +70,8 @@ class Activity(AgentsModel):
5070 :type local_timezone: str
5171 :param service_url: Contains the URL that specifies the channel's service endpoint. Set by the channel.
5272 :type service_url: str
53- :param channel_id: Contains an ID that uniquely identifies the channel. Set by the channel.
54- :type channel_id: str
73+ :param channel_id: Contains an ID that uniquely identifies the channel (and possibly the sub-channel) . Set by the channel.
74+ :type channel_id: ~microsoft_agents.activity.ChannelId
5575 :param from_property: Identifies the sender of the message.
5676 :type from_property: ~microsoft_agents.activity.ChannelAccount
5777 :param conversation: Identifies the conversation to which the activity belongs.
@@ -136,7 +156,6 @@ class Activity(AgentsModel):
136156 local_timestamp : datetime = None
137157 local_timezone : NonEmptyString = None
138158 service_url : NonEmptyString = None
139- channel_id : NonEmptyString = None
140159 from_property : ChannelAccount = Field (None , alias = "from" )
141160 conversation : ConversationAccount = None
142161 recipient : ChannelAccount = None
@@ -173,6 +192,92 @@ class Activity(AgentsModel):
173192 semantic_action : SemanticAction = None
174193 caller_id : NonEmptyString = None
175194
195+ @model_validator (mode = "wrap" )
196+ @classmethod
197+ def _validate_channel_id (
198+ cls , data : Any , handler : ModelWrapValidatorHandler [Activity ]
199+ ) -> Activity :
200+ """Validate the Activity, ensuring consistency between channel_id.sub_channel and productInfo entity.
201+
202+ :param data: The input data to validate.
203+ :param handler: The validation handler provided by Pydantic.
204+ :return: The validated Activity instance.
205+ """
206+ try :
207+ # run Pydantic's standard validation first
208+ activity = handler (data )
209+
210+ # needed to assign to a computed field
211+ # needed because we override the mixin validator
212+ activity ._set_validated_channel_id (data )
213+
214+ # sync sub_channel with productInfo entity
215+ product_info = activity .get_product_info_entity ()
216+ if product_info and activity .channel_id :
217+ if (
218+ activity .channel_id .sub_channel
219+ and activity .channel_id .sub_channel != product_info .id
220+ ):
221+ raise Exception (
222+ "Conflict between channel_id.sub_channel and productInfo entity"
223+ )
224+ activity .channel_id = ChannelId (
225+ channel = activity .channel_id .channel ,
226+ sub_channel = product_info .id ,
227+ )
228+
229+ return activity
230+ except ValidationError as exc :
231+ logger .error ("Validation error for Activity: %s" , exc , exc_info = True )
232+ raise
233+
234+ @model_serializer (mode = "wrap" )
235+ def _serialize_sub_channel_data (
236+ self , handler : SerializerFunctionWrapHandler
237+ ) -> dict [str , object ]:
238+ """Serialize the Activity, ensuring consistency between channel_id.sub_channel and productInfo entity.
239+
240+ :param handler: The serialization handler provided by Pydantic.
241+ :return: A dictionary representing the serialized Activity.
242+ """
243+
244+ # run Pydantic's standard serialization first
245+ serialized = handler (self )
246+ if not self : # serialization can be called with None
247+ return serialized
248+
249+ # find the ProductInfo entity
250+ product_info = None
251+ for i , entity in enumerate (serialized .get ("entities" ) or []):
252+ if entity .get ("type" , "" ) == EntityTypes .PRODUCT_INFO :
253+ product_info = entity
254+ break
255+
256+ # maintain consistency between ProductInfo entity and sub channel
257+ if self .channel_id and self .channel_id .sub_channel :
258+ if product_info and product_info .get ("id" ) != self .channel_id .sub_channel :
259+ raise Exception (
260+ "Conflict between channel_id.sub_channel and productInfo entity"
261+ )
262+ elif not product_info :
263+ if not serialized .get ("entities" ):
264+ serialized ["entities" ] = []
265+ serialized ["entities" ].append (
266+ {
267+ "type" : EntityTypes .PRODUCT_INFO ,
268+ "id" : self .channel_id .sub_channel ,
269+ }
270+ )
271+ elif product_info : # remove productInfo entity if sub_channel is not set
272+ del serialized ["entities" ][i ]
273+ if not serialized ["entities" ]: # after removal above, list may be empty
274+ del serialized ["entities" ]
275+
276+ # necessary due to computed_field serialization
277+ self ._remove_serialized_unset_channel_id (serialized )
278+
279+ return serialized
280+
176281 def apply_conversation_reference (
177282 self , reference : ConversationReference , is_incoming : bool = False
178283 ):
@@ -531,6 +636,14 @@ def get_conversation_reference(self) -> ConversationReference:
531636 service_url = self .service_url ,
532637 )
533638
639+ def get_product_info_entity (self ) -> Optional [ProductInfo ]:
640+ if not self .entities :
641+ return None
642+ target = EntityTypes .PRODUCT_INFO .lower ()
643+ # validated entities can be Entity, and that prevents us from
644+ # making assumptions about the casing of the 'type' attribute
645+ return next (filter (lambda e : e .type .lower () == target , self .entities ), None )
646+
534647 def get_mentions (self ) -> list [Mention ]:
535648 """
536649 Resolves the mentions from the entities of this activity.
@@ -543,7 +656,7 @@ def get_mentions(self) -> list[Mention]:
543656 """
544657 if not self .entities :
545658 return []
546- return [x for x in self .entities if x .type .lower () == "mention" ]
659+ return [x for x in self .entities if x .type .lower () == EntityTypes . MENTION ]
547660
548661 def get_reply_conversation_reference (
549662 self , reply : ResourceResponse
0 commit comments