Skip to content

Commit 6f92ee4

Browse files
authored
Merge branch 'main' into devui_move_samples
2 parents 5ce34bf + a36e183 commit 6f92ee4

File tree

5 files changed

+119
-68
lines changed

5 files changed

+119
-68
lines changed

python/packages/core/agent_framework/_serialization.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,17 +212,18 @@ def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True)
212212

213213
return result
214214

215-
def to_json(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> str:
215+
def to_json(self, *, exclude: set[str] | None = None, exclude_none: bool = True, **kwargs: Any) -> str:
216216
"""Convert the instance to a JSON string.
217217
218218
Keyword Args:
219219
exclude: The set of field names to exclude from serialization.
220220
exclude_none: Whether to exclude None values from the output. Defaults to True.
221+
**kwargs: passed through to the json.dumps method.
221222
222223
Returns:
223224
JSON string representation of the instance.
224225
"""
225-
return json.dumps(self.to_dict(exclude=exclude, exclude_none=exclude_none))
226+
return json.dumps(self.to_dict(exclude=exclude, exclude_none=exclude_none), **kwargs)
226227

227228
@classmethod
228229
def from_dict(

python/packages/core/agent_framework/_threads.py

Lines changed: 71 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

3-
from collections.abc import Sequence
3+
from collections.abc import MutableMapping, Sequence
44
from typing import Any, Protocol, TypeVar
55

6-
from pydantic import BaseModel, ConfigDict, model_validator
7-
86
from ._memory import AggregateContextProvider
7+
from ._serialization import SerializationMixin
98
from ._types import ChatMessage
109
from .exceptions import AgentThreadException
1110

@@ -73,7 +72,9 @@ async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
7372
...
7473

7574
@classmethod
76-
async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "ChatMessageStoreProtocol":
75+
async def deserialize(
76+
cls, serialized_store_state: MutableMapping[str, Any], **kwargs: Any
77+
) -> "ChatMessageStoreProtocol":
7778
"""Creates a new instance of the store from previously serialized state.
7879
7980
This method, together with ``serialize()`` can be used to save and load messages from a persistent store
@@ -90,7 +91,7 @@ async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "ChatM
9091
"""
9192
...
9293

93-
async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
94+
async def update_from_state(self, serialized_store_state: MutableMapping[str, Any], **kwargs: Any) -> None:
9495
"""Update the current ChatMessageStore instance from serialized state data.
9596
9697
Args:
@@ -101,7 +102,7 @@ async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) ->
101102
"""
102103
...
103104

104-
async def serialize(self, **kwargs: Any) -> Any:
105+
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
105106
"""Serializes the current object's state.
106107
107108
This method, together with ``deserialize()`` can be used to save and load messages from a persistent store
@@ -116,40 +117,66 @@ async def serialize(self, **kwargs: Any) -> Any:
116117
...
117118

118119

119-
class ChatMessageStoreState(BaseModel):
120+
class ChatMessageStoreState(SerializationMixin):
120121
"""State model for serializing and deserializing chat message store data.
121122
122123
Attributes:
123124
messages: List of chat messages stored in the message store.
124125
"""
125126

126-
messages: list[ChatMessage]
127-
128-
model_config = ConfigDict(arbitrary_types_allowed=True)
129-
127+
def __init__(
128+
self,
129+
messages: Sequence[ChatMessage] | Sequence[MutableMapping[str, Any]] | None = None,
130+
**kwargs: Any,
131+
) -> None:
132+
"""Create the store state.
130133
131-
class AgentThreadState(BaseModel):
132-
"""State model for serializing and deserializing thread information.
134+
Args:
135+
messages: a list of messages or a list of the dict representation of messages.
133136
134-
Attributes:
135-
service_thread_id: Optional ID of the thread managed by the agent service.
136-
chat_message_store_state: Optional serialized state of the chat message store.
137-
"""
137+
Keyword Args:
138+
**kwargs: not used for this, but might be used by subclasses.
138139
139-
service_thread_id: str | None = None
140-
chat_message_store_state: ChatMessageStoreState | None = None
140+
"""
141+
if not messages:
142+
self.messages: list[ChatMessage] = []
143+
if not isinstance(messages, list):
144+
raise TypeError("Messages should be a list")
145+
new_messages: list[ChatMessage] = []
146+
for msg in messages:
147+
if isinstance(msg, ChatMessage):
148+
new_messages.append(msg)
149+
else:
150+
new_messages.append(ChatMessage.from_dict(msg))
151+
self.messages = new_messages
152+
153+
154+
class AgentThreadState(SerializationMixin):
155+
"""State model for serializing and deserializing thread information."""
141156

142-
model_config = ConfigDict(arbitrary_types_allowed=True)
157+
def __init__(
158+
self,
159+
*,
160+
service_thread_id: str | None = None,
161+
chat_message_store_state: ChatMessageStoreState | MutableMapping[str, Any] | None = None,
162+
) -> None:
163+
"""Create a AgentThread state.
143164
144-
@model_validator(mode="before")
145-
def validate_only_one(cls, values: dict[str, Any]) -> dict[str, Any]:
146-
if (
147-
isinstance(values, dict)
148-
and values.get("service_thread_id") is not None
149-
and values.get("chat_message_store_state") is not None
150-
):
151-
raise AgentThreadException("Only one of service_thread_id or chat_message_store_state may be set.")
152-
return values
165+
Keyword Args:
166+
service_thread_id: Optional ID of the thread managed by the agent service.
167+
chat_message_store_state: Optional serialized state of the chat message store.
168+
"""
169+
if service_thread_id is not None and chat_message_store_state is not None:
170+
raise AgentThreadException("A thread cannot have both a service_thread_id and a chat_message_store.")
171+
self.service_thread_id = service_thread_id
172+
self.chat_message_store_state: ChatMessageStoreState | None = None
173+
if chat_message_store_state is not None:
174+
if isinstance(chat_message_store_state, dict):
175+
self.chat_message_store_state = ChatMessageStoreState.from_dict(chat_message_store_state)
176+
elif isinstance(chat_message_store_state, ChatMessageStoreState):
177+
self.chat_message_store_state = chat_message_store_state
178+
else:
179+
raise TypeError("Could not parse ChatMessageStoreState.")
153180

154181

155182
TChatMessageStore = TypeVar("TChatMessageStore", bound="ChatMessageStore")
@@ -213,7 +240,7 @@ async def list_messages(self) -> list[ChatMessage]:
213240

214241
@classmethod
215242
async def deserialize(
216-
cls: type[TChatMessageStore], serialized_store_state: Any, **kwargs: Any
243+
cls: type[TChatMessageStore], serialized_store_state: MutableMapping[str, Any], **kwargs: Any
217244
) -> TChatMessageStore:
218245
"""Create a new ChatMessageStore instance from serialized state data.
219246
@@ -226,12 +253,12 @@ async def deserialize(
226253
Returns:
227254
A new ChatMessageStore instance populated with messages from the serialized state.
228255
"""
229-
state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
256+
state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs)
230257
if state.messages:
231258
return cls(messages=state.messages)
232259
return cls()
233260

234-
async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None:
261+
async def update_from_state(self, serialized_store_state: MutableMapping[str, Any], **kwargs: Any) -> None:
235262
"""Update the current ChatMessageStore instance from serialized state data.
236263
237264
Args:
@@ -242,11 +269,11 @@ async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) ->
242269
"""
243270
if not serialized_store_state:
244271
return
245-
state = ChatMessageStoreState.model_validate(serialized_store_state, **kwargs)
272+
state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs)
246273
if state.messages:
247274
self.messages = state.messages
248275

249-
async def serialize(self, **kwargs: Any) -> Any:
276+
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
250277
"""Serialize the current store state for persistence.
251278
252279
Keyword Args:
@@ -256,7 +283,7 @@ async def serialize(self, **kwargs: Any) -> Any:
256283
Serialized state data that can be used with deserialize_state.
257284
"""
258285
state = ChatMessageStoreState(messages=self.messages)
259-
return state.model_dump(**kwargs)
286+
return state.to_dict()
260287

261288

262289
TAgentThread = TypeVar("TAgentThread", bound="AgentThread")
@@ -403,12 +430,12 @@ async def serialize(self, **kwargs: Any) -> dict[str, Any]:
403430
state = AgentThreadState(
404431
service_thread_id=self._service_thread_id, chat_message_store_state=chat_message_store_state
405432
)
406-
return state.model_dump()
433+
return state.to_dict(exclude_none=False)
407434

408435
@classmethod
409436
async def deserialize(
410437
cls: type[TAgentThread],
411-
serialized_thread_state: dict[str, Any],
438+
serialized_thread_state: MutableMapping[str, Any],
412439
*,
413440
message_store: ChatMessageStoreProtocol | None = None,
414441
**kwargs: Any,
@@ -426,7 +453,7 @@ async def deserialize(
426453
Returns:
427454
A new AgentThread instance with properties set from the serialized state.
428455
"""
429-
state = AgentThreadState.model_validate(serialized_thread_state)
456+
state = AgentThreadState.from_dict(serialized_thread_state)
430457

431458
if state.service_thread_id is not None:
432459
return cls(service_thread_id=state.service_thread_id)
@@ -437,19 +464,19 @@ async def deserialize(
437464

438465
if message_store is not None:
439466
try:
440-
await message_store.update_from_state(state.chat_message_store_state, **kwargs)
467+
await message_store.add_messages(state.chat_message_store_state.messages, **kwargs)
441468
except Exception as ex:
442469
raise AgentThreadException("Failed to deserialize the provided message store.") from ex
443470
return cls(message_store=message_store)
444471
try:
445-
message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs)
472+
message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs)
446473
except Exception as ex:
447474
raise AgentThreadException("Failed to deserialize the message store.") from ex
448475
return cls(message_store=message_store)
449476

450477
async def update_from_thread_state(
451478
self,
452-
serialized_thread_state: dict[str, Any],
479+
serialized_thread_state: MutableMapping[str, Any],
453480
**kwargs: Any,
454481
) -> None:
455482
"""Deserializes the state from a dictionary into the thread properties.
@@ -460,7 +487,7 @@ async def update_from_thread_state(
460487
Keyword Args:
461488
**kwargs: Additional arguments for deserialization.
462489
"""
463-
state = AgentThreadState.model_validate(serialized_thread_state)
490+
state = AgentThreadState.from_dict(serialized_thread_state)
464491

465492
if state.service_thread_id is not None:
466493
self.service_thread_id = state.service_thread_id
@@ -470,8 +497,8 @@ async def update_from_thread_state(
470497
if state.chat_message_store_state is None:
471498
return
472499
if self.message_store is not None:
473-
await self.message_store.update_from_state(state.chat_message_store_state, **kwargs)
500+
await self.message_store.add_messages(state.chat_message_store_state.messages, **kwargs)
474501
# If we don't have a chat message store yet, create an in-memory one.
475502
return
476503
# Create the message store from the default.
477-
self.message_store = await ChatMessageStore.deserialize(state.chat_message_store_state, **kwargs) # type: ignore
504+
self.message_store = ChatMessageStore(messages=state.chat_message_store_state.messages, **kwargs)

python/packages/core/tests/core/test_threads.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,15 @@ async def test_deserialize_with_existing_store(self) -> None:
224224
"""Test _deserialize with existing message store."""
225225
store = MockChatMessageStore()
226226
thread = AgentThread(message_store=store)
227-
serialized_data: dict[str, Any] = {"service_thread_id": None, "chat_message_store_state": {"messages": []}}
227+
serialized_data: dict[str, Any] = {
228+
"service_thread_id": None,
229+
"chat_message_store_state": {"messages": [ChatMessage(role="user", text="test")]},
230+
}
228231

229232
await thread.update_from_thread_state(serialized_data)
230233

231-
assert store._deserialize_calls == 1 # pyright: ignore[reportPrivateUsage]
234+
assert store._messages
235+
assert store._messages[0].text == "test"
232236

233237
async def test_serialize_with_service_thread_id(self) -> None:
234238
"""Test serialize with service_thread_id."""
@@ -268,6 +272,23 @@ async def test_serialize_with_kwargs(self) -> None:
268272

269273
assert store._serialize_calls == 1 # pyright: ignore[reportPrivateUsage]
270274

275+
async def test_serialize_round_trip_messages(self, sample_messages: list[ChatMessage]) -> None:
276+
"""Test a roundtrip of the serialization."""
277+
store = ChatMessageStore(sample_messages)
278+
thread = AgentThread(message_store=store)
279+
new_thread = await AgentThread.deserialize(await thread.serialize())
280+
assert new_thread.message_store is not None
281+
new_messages = await new_thread.message_store.list_messages()
282+
assert len(new_messages) == len(sample_messages)
283+
assert {new.text for new in new_messages} == {orig.text for orig in sample_messages}
284+
285+
async def test_serialize_round_trip_thread_id(self) -> None:
286+
"""Test a roundtrip of the serialization."""
287+
thread = AgentThread(service_thread_id="test-1234")
288+
new_thread = await AgentThread.deserialize(await thread.serialize())
289+
assert new_thread.message_store is None
290+
assert new_thread.service_thread_id == "test-1234"
291+
271292

272293
class TestChatMessageList:
273294
"""Test cases for ChatMessageStore class."""
@@ -377,17 +398,15 @@ def test_init_with_service_thread_id(self) -> None:
377398
def test_init_with_chat_message_store_state(self) -> None:
378399
"""Test AgentThreadState initialization with chat_message_store_state."""
379400
store_data: dict[str, Any] = {"messages": []}
380-
state = AgentThreadState.model_validate({"chat_message_store_state": store_data})
401+
state = AgentThreadState.from_dict({"chat_message_store_state": store_data})
381402

382403
assert state.service_thread_id is None
383404
assert state.chat_message_store_state.messages == []
384405

385406
def test_init_with_both(self) -> None:
386407
"""Test AgentThreadState initialization with both parameters."""
387408
store_data: dict[str, Any] = {"messages": []}
388-
with pytest.raises(
389-
AgentThreadException, match="Only one of service_thread_id or chat_message_store_state may be set"
390-
):
409+
with pytest.raises(AgentThreadException):
391410
AgentThreadState(service_thread_id="test-conv-123", chat_message_store_state=store_data)
392411

393412
def test_init_defaults(self) -> None:

0 commit comments

Comments
 (0)