Skip to content

Commit 87c9d74

Browse files
Copilot and TaoChenOSU authored
Python: Fix: Verify types during checkpoint deserialization to prevent marker spoofing (#3243)
* Initial plan

* Add validation for reserved keywords in checkpoint encoding/decoding

  Co-authored-by: TaoChenOSU <[email protected]>

* Refactor to eliminate duplicate code in model protocol detection

  Co-authored-by: TaoChenOSU <[email protected]>

* Fix pyright type narrowing issue for dataclass check

  Co-authored-by: TaoChenOSU <[email protected]>

* Add comprehensive unit tests for checkpoint encoding

  Co-authored-by: TaoChenOSU <[email protected]>

* Remove serialization-time reserved keyword validation to fix failing tests

  The serialization-time validation was too aggressive and blocked legitimate use cases where encoded data was being re-encoded. Security is now enforced only at deserialization time by validating that classes marked with DATACLASS_MARKER are actual dataclasses and classes marked with MODEL_MARKER actually support the model protocol.

  Co-authored-by: TaoChenOSU <[email protected]>

* Apply ruff formatting to checkpoint encoding file

  Co-authored-by: TaoChenOSU <[email protected]>

* Changes before error encountered

  Co-authored-by: TaoChenOSU <[email protected]>

* Revert "Changes before error encountered"

  This reverts commit f515b88.

---------

Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: TaoChenOSU <[email protected]>
Co-authored-by: Tao Chen <[email protected]>
1 parent 958e6d2 commit 87c9d74

File tree

3 files changed

+558
-7
lines changed

3 files changed

+558
-7
lines changed

python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ def decode_checkpoint_value(value: Any) -> Any:
146146
cls = None
147147

148148
if cls is not None:
149+
# Verify the class actually supports the model protocol
150+
if not _class_supports_model_protocol(cls):
151+
logger.debug(f"Class {type_key} does not support model protocol; returning raw value")
152+
return decoded_payload
149153
if strategy == "to_dict" and hasattr(cls, "from_dict"):
150154
with contextlib.suppress(Exception):
151155
return cls.from_dict(decoded_payload)
@@ -169,6 +173,10 @@ def decode_checkpoint_value(value: Any) -> Any:
169173
if module is None:
170174
module = importlib.import_module(module_name)
171175
cls_dc: Any = getattr(module, class_name)
176+
# Verify the class is actually a dataclass type (not an instance)
177+
if not isinstance(cls_dc, type) or not is_dataclass(cls_dc):
178+
logger.debug(f"Class {type_key_dc} is not a dataclass type; returning raw value")
179+
return decoded_raw
172180
constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw)
173181
if constructed is not None:
174182
return constructed
@@ -188,20 +196,30 @@ def decode_checkpoint_value(value: Any) -> Any:
188196
return value
189197

190198

199+
def _class_supports_model_protocol(cls: type[Any]) -> bool:
    """Check if a class type supports the model serialization protocol.

    A class qualifies when it exposes at least one complete, callable pair of
    serialization/deserialization hooks:
    - to_dict/from_dict
    - to_json/from_json

    Args:
        cls: The candidate class object. Anything that is not a class (for
            example an instance, a function, or a module-level constant) is
            rejected.

    Returns:
        True if ``cls`` is a class exposing a full hook pair, otherwise False.
    """
    # Only genuine classes may satisfy the protocol. A type marker taken from
    # checkpoint data could resolve to an arbitrary object that merely carries
    # callable attributes with the right names; such objects must not pass.
    if not isinstance(cls, type):
        return False

    def _has_callable(name: str) -> bool:
        # getattr with a default already covers the missing-attribute case,
        # so a separate hasattr lookup is unnecessary.
        return callable(getattr(cls, name, None))

    return (_has_callable("to_dict") and _has_callable("from_dict")) or (
        _has_callable("to_json") and _has_callable("from_json")
    )
213+
214+
191215
def _supports_model_protocol(obj: object) -> bool:
    """Report whether the type of ``obj`` exposes model serialization hooks.

    The decision is delegated to ``_class_supports_model_protocol`` using the
    object's class; False is returned if the class cannot be determined.
    """
    try:
        resolved_type: type[Any] = type(obj)
    except Exception:
        return False
    else:
        return _class_supports_model_protocol(resolved_type)
205223

206224

207225
def _import_qualified_name(qualname: str) -> type[Any] | None:

python/packages/core/tests/workflow/test_checkpoint_decode.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
from dataclasses import dataclass # noqa: I001
44
from typing import Any, cast
55

6+
67
from agent_framework._workflows._checkpoint_encoding import (
8+
DATACLASS_MARKER,
9+
MODEL_MARKER,
710
decode_checkpoint_value,
811
encode_checkpoint_value,
912
)
@@ -126,3 +129,110 @@ def test_encode_decode_nested_structures() -> None:
126129
assert response.data == "first response"
127130
assert isinstance(response.original_request, SampleRequest)
128131
assert response.original_request.request_id == "req-1"
132+
133+
134+
def test_encode_allows_marker_key_without_value_key() -> None:
    """Encoding accepts a dict that has the marker key but no 'value' key."""
    payload = {
        MODEL_MARKER: "some.module:FakeClass",
        "other_key": "test",
    }
    result = encode_checkpoint_value(payload)
    # Both original keys must still be present after encoding.
    for expected_key in (MODEL_MARKER, "other_key"):
        assert expected_key in result
143+
144+
145+
def test_encode_allows_value_key_without_marker_key() -> None:
    """Encoding accepts a dict that has a 'value' key but no marker key."""
    payload = {
        "value": {"data": "test"},
        "other_key": "test",
    }
    result = encode_checkpoint_value(payload)
    # Both original keys must still be present after encoding.
    for expected_key in ("value", "other_key"):
        assert expected_key in result
154+
155+
156+
def test_encode_allows_marker_with_value_key() -> None:
    """Encoding accepts a dict holding the model marker together with a 'value' key.

    Legitimate, already-encoded data may contain both keys; protection against
    spoofed markers is applied at deserialization time by validating class types.
    """
    payload = {
        MODEL_MARKER: "some.module:SomeClass",
        "value": {"data": "test"},
        "strategy": "to_dict",
    }
    result = encode_checkpoint_value(payload)
    for expected_key in (MODEL_MARKER, "value"):
        assert expected_key in result
170+
171+
172+
class NotADataclass:
    """Plain class fixture that is deliberately not decorated as a dataclass."""

    def __init__(self, value: str) -> None:
        # Keep the constructor argument on the instance unchanged.
        self.value: str = value

    def get_value(self) -> str:
        """Return the string supplied at construction."""
        return self.value
180+
181+
182+
class NotAModel:
    """Plain class fixture that deliberately lacks the model protocol hooks."""

    def __init__(self, value: str) -> None:
        # Keep the constructor argument on the instance unchanged.
        self.value: str = value

    def get_value(self) -> str:
        """Return the string supplied at construction."""
        return self.value
190+
191+
192+
def test_decode_rejects_non_dataclass_with_dataclass_marker() -> None:
    """Decoding yields the raw value when the marked class is not a dataclass."""
    # Hand-built payload that falsely labels NotADataclass as a dataclass.
    spoofed_type_key = f"{NotADataclass.__module__}:{NotADataclass.__name__}"
    spoofed_payload = {
        DATACLASS_MARKER: spoofed_type_key,
        "value": {"value": "test_value"},
    }

    result = decode_checkpoint_value(spoofed_payload)

    # The decoder must hand back the plain dict rather than a NotADataclass instance.
    assert isinstance(result, dict)
    assert result["value"] == "test_value"
205+
206+
207+
def test_decode_rejects_non_model_with_model_marker() -> None:
    """Decoding yields the raw value when the marked class lacks the model protocol."""
    # Hand-built payload that falsely labels NotAModel as a protocol-supporting model.
    spoofed_type_key = f"{NotAModel.__module__}:{NotAModel.__name__}"
    spoofed_payload = {
        MODEL_MARKER: spoofed_type_key,
        "strategy": "to_dict",
        "value": {"value": "test_value"},
    }

    result = decode_checkpoint_value(spoofed_payload)

    # The decoder must hand back the plain dict rather than a NotAModel instance.
    assert isinstance(result, dict)
    assert result["value"] == "test_value"
221+
222+
223+
def test_encode_allows_nested_dict_with_marker_keys() -> None:
    """Encoding accepts marker-shaped dicts nested inside other dicts.

    Validation happens at deserialization time rather than serialization time,
    so legitimately encoded data may carry markers at any nesting depth.
    """
    inner = {
        MODEL_MARKER: "some.module:SomeClass",
        "value": {"data": "test"},
    }

    result = encode_checkpoint_value({"outer": inner})

    assert "outer" in result
    assert MODEL_MARKER in result["outer"]

0 commit comments

Comments
 (0)