11# Copyright (c) Microsoft. All rights reserved.
22
3- from collections .abc import Sequence
3+ from collections .abc import MutableMapping , Sequence
44from typing import Any , Protocol , TypeVar
55
6- from pydantic import BaseModel , ConfigDict , model_validator
7-
86from ._memory import AggregateContextProvider
7+ from ._serialization import SerializationMixin
98from ._types import ChatMessage
109from .exceptions import AgentThreadException
1110
@@ -73,7 +72,9 @@ async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
7372 ...
7473
7574 @classmethod
76- async def deserialize (cls , serialized_store_state : Any , ** kwargs : Any ) -> "ChatMessageStoreProtocol" :
75+ async def deserialize (
76+ cls , serialized_store_state : MutableMapping [str , Any ], ** kwargs : Any
77+ ) -> "ChatMessageStoreProtocol" :
7778 """Creates a new instance of the store from previously serialized state.
7879
7980 This method, together with ``serialize()`` can be used to save and load messages from a persistent store
@@ -90,7 +91,7 @@ async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "ChatM
9091 """
9192 ...
9293
93- async def update_from_state (self , serialized_store_state : Any , ** kwargs : Any ) -> None :
94+ async def update_from_state (self , serialized_store_state : MutableMapping [ str , Any ] , ** kwargs : Any ) -> None :
9495 """Update the current ChatMessageStore instance from serialized state data.
9596
9697 Args:
@@ -101,7 +102,7 @@ async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) ->
101102 """
102103 ...
103104
104- async def serialize (self , ** kwargs : Any ) -> Any :
105+ async def serialize (self , ** kwargs : Any ) -> dict [ str , Any ] :
105106 """Serializes the current object's state.
106107
107108 This method, together with ``deserialize()`` can be used to save and load messages from a persistent store
@@ -116,40 +117,66 @@ async def serialize(self, **kwargs: Any) -> Any:
116117 ...
117118
118119
119- class ChatMessageStoreState (BaseModel ):
120+ class ChatMessageStoreState (SerializationMixin ):
120121 """State model for serializing and deserializing chat message store data.
121122
122123 Attributes:
123124 messages: List of chat messages stored in the message store.
124125 """
125126
126- messages : list [ChatMessage ]
127-
128- model_config = ConfigDict (arbitrary_types_allowed = True )
129-
127+ def __init__ (
128+ self ,
129+ messages : Sequence [ChatMessage ] | Sequence [MutableMapping [str , Any ]] | None = None ,
130+ ** kwargs : Any ,
131+ ) -> None :
132+ """Create the store state.
130133
131- class AgentThreadState ( BaseModel ) :
132- """State model for serializing and deserializing thread information .
134+ Args :
135+ messages: a list of messages or a list of the dict representation of messages .
133136
134- Attributes:
135- service_thread_id: Optional ID of the thread managed by the agent service.
136- chat_message_store_state: Optional serialized state of the chat message store.
137- """
137+ Keyword Args:
138+ **kwargs: not used for this, but might be used by subclasses.
138139
139- service_thread_id : str | None = None
140- chat_message_store_state : ChatMessageStoreState | None = None
140+ """
141+ if not messages :
142+ self .messages : list [ChatMessage ] = []
143+ if not isinstance (messages , list ):
144+ raise TypeError ("Messages should be a list" )
145+ new_messages : list [ChatMessage ] = []
146+ for msg in messages :
147+ if isinstance (msg , ChatMessage ):
148+ new_messages .append (msg )
149+ else :
150+ new_messages .append (ChatMessage .from_dict (msg ))
151+ self .messages = new_messages
152+
153+
154+ class AgentThreadState (SerializationMixin ):
155+ """State model for serializing and deserializing thread information."""
141156
142- model_config = ConfigDict (arbitrary_types_allowed = True )
157+ def __init__ (
158+ self ,
159+ * ,
160+ service_thread_id : str | None = None ,
161+ chat_message_store_state : ChatMessageStoreState | MutableMapping [str , Any ] | None = None ,
162+ ) -> None :
163+ """Create a AgentThread state.
143164
144- @model_validator (mode = "before" )
145- def validate_only_one (cls , values : dict [str , Any ]) -> dict [str , Any ]:
146- if (
147- isinstance (values , dict )
148- and values .get ("service_thread_id" ) is not None
149- and values .get ("chat_message_store_state" ) is not None
150- ):
151- raise AgentThreadException ("Only one of service_thread_id or chat_message_store_state may be set." )
152- return values
165+ Keyword Args:
166+ service_thread_id: Optional ID of the thread managed by the agent service.
167+ chat_message_store_state: Optional serialized state of the chat message store.
168+ """
169+ if service_thread_id is not None and chat_message_store_state is not None :
170+ raise AgentThreadException ("A thread cannot have both a service_thread_id and a chat_message_store." )
171+ self .service_thread_id = service_thread_id
172+ self .chat_message_store_state : ChatMessageStoreState | None = None
173+ if chat_message_store_state is not None :
174+ if isinstance (chat_message_store_state , dict ):
175+ self .chat_message_store_state = ChatMessageStoreState .from_dict (chat_message_store_state )
176+ elif isinstance (chat_message_store_state , ChatMessageStoreState ):
177+ self .chat_message_store_state = chat_message_store_state
178+ else :
179+ raise TypeError ("Could not parse ChatMessageStoreState." )
153180
154181
155182TChatMessageStore = TypeVar ("TChatMessageStore" , bound = "ChatMessageStore" )
@@ -213,7 +240,7 @@ async def list_messages(self) -> list[ChatMessage]:
213240
214241 @classmethod
215242 async def deserialize (
216- cls : type [TChatMessageStore ], serialized_store_state : Any , ** kwargs : Any
243+ cls : type [TChatMessageStore ], serialized_store_state : MutableMapping [ str , Any ] , ** kwargs : Any
217244 ) -> TChatMessageStore :
218245 """Create a new ChatMessageStore instance from serialized state data.
219246
@@ -226,12 +253,12 @@ async def deserialize(
226253 Returns:
227254 A new ChatMessageStore instance populated with messages from the serialized state.
228255 """
229- state = ChatMessageStoreState .model_validate (serialized_store_state , ** kwargs )
256+ state = ChatMessageStoreState .from_dict (serialized_store_state , ** kwargs )
230257 if state .messages :
231258 return cls (messages = state .messages )
232259 return cls ()
233260
234- async def update_from_state (self , serialized_store_state : Any , ** kwargs : Any ) -> None :
261+ async def update_from_state (self , serialized_store_state : MutableMapping [ str , Any ] , ** kwargs : Any ) -> None :
235262 """Update the current ChatMessageStore instance from serialized state data.
236263
237264 Args:
@@ -242,11 +269,11 @@ async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) ->
242269 """
243270 if not serialized_store_state :
244271 return
245- state = ChatMessageStoreState .model_validate (serialized_store_state , ** kwargs )
272+ state = ChatMessageStoreState .from_dict (serialized_store_state , ** kwargs )
246273 if state .messages :
247274 self .messages = state .messages
248275
249- async def serialize (self , ** kwargs : Any ) -> Any :
276+ async def serialize (self , ** kwargs : Any ) -> dict [ str , Any ] :
250277 """Serialize the current store state for persistence.
251278
252279 Keyword Args:
@@ -256,7 +283,7 @@ async def serialize(self, **kwargs: Any) -> Any:
256283 Serialized state data that can be used with deserialize_state.
257284 """
258285 state = ChatMessageStoreState (messages = self .messages )
259- return state .model_dump ( ** kwargs )
286+ return state .to_dict ( )
260287
261288
262289TAgentThread = TypeVar ("TAgentThread" , bound = "AgentThread" )
@@ -403,12 +430,12 @@ async def serialize(self, **kwargs: Any) -> dict[str, Any]:
403430 state = AgentThreadState (
404431 service_thread_id = self ._service_thread_id , chat_message_store_state = chat_message_store_state
405432 )
406- return state .model_dump ( )
433+ return state .to_dict ( exclude_none = False )
407434
408435 @classmethod
409436 async def deserialize (
410437 cls : type [TAgentThread ],
411- serialized_thread_state : dict [str , Any ],
438+ serialized_thread_state : MutableMapping [str , Any ],
412439 * ,
413440 message_store : ChatMessageStoreProtocol | None = None ,
414441 ** kwargs : Any ,
@@ -426,7 +453,7 @@ async def deserialize(
426453 Returns:
427454 A new AgentThread instance with properties set from the serialized state.
428455 """
429- state = AgentThreadState .model_validate (serialized_thread_state )
456+ state = AgentThreadState .from_dict (serialized_thread_state )
430457
431458 if state .service_thread_id is not None :
432459 return cls (service_thread_id = state .service_thread_id )
@@ -437,19 +464,19 @@ async def deserialize(
437464
438465 if message_store is not None :
439466 try :
440- await message_store .update_from_state (state .chat_message_store_state , ** kwargs )
467+ await message_store .add_messages (state .chat_message_store_state . messages , ** kwargs )
441468 except Exception as ex :
442469 raise AgentThreadException ("Failed to deserialize the provided message store." ) from ex
443470 return cls (message_store = message_store )
444471 try :
445- message_store = await ChatMessageStore . deserialize ( state .chat_message_store_state , ** kwargs )
472+ message_store = ChatMessageStore ( messages = state .chat_message_store_state . messages , ** kwargs )
446473 except Exception as ex :
447474 raise AgentThreadException ("Failed to deserialize the message store." ) from ex
448475 return cls (message_store = message_store )
449476
450477 async def update_from_thread_state (
451478 self ,
452- serialized_thread_state : dict [str , Any ],
479+ serialized_thread_state : MutableMapping [str , Any ],
453480 ** kwargs : Any ,
454481 ) -> None :
455482 """Deserializes the state from a dictionary into the thread properties.
@@ -460,7 +487,7 @@ async def update_from_thread_state(
460487 Keyword Args:
461488 **kwargs: Additional arguments for deserialization.
462489 """
463- state = AgentThreadState .model_validate (serialized_thread_state )
490+ state = AgentThreadState .from_dict (serialized_thread_state )
464491
465492 if state .service_thread_id is not None :
466493 self .service_thread_id = state .service_thread_id
@@ -470,8 +497,8 @@ async def update_from_thread_state(
470497 if state .chat_message_store_state is None :
471498 return
472499 if self .message_store is not None :
473- await self .message_store .update_from_state (state .chat_message_store_state , ** kwargs )
500+ await self .message_store .add_messages (state .chat_message_store_state . messages , ** kwargs )
474501 # If we don't have a chat message store yet, create an in-memory one.
475502 return
476503 # Create the message store from the default.
477- self .message_store = await ChatMessageStore . deserialize ( state .chat_message_store_state , ** kwargs ) # type: ignore
504+ self .message_store = ChatMessageStore ( messages = state .chat_message_store_state . messages , ** kwargs )
0 commit comments