Skip to content

Commit 985f5d4

Browse files
added support for formatting examples (#515)
Co-authored-by: Samuel Colvin <[email protected]>
1 parent 8cd58ec commit 985f5d4

File tree

5 files changed

+343
-12
lines changed

5 files changed

+343
-12
lines changed

docs/api/format_as_xml.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# `pydantic_ai.format_as_xml`
2+
3+
::: pydantic_ai.format_as_xml

examples/pydantic_ai_examples/sql_gen.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing_extensions import TypeAlias
2727

2828
from pydantic_ai import Agent, ModelRetry, RunContext
29+
from pydantic_ai.format_as_xml import format_as_xml
2930

3031
# 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured
3132
logfire.configure(send_to_logfire='if-token-present')
@@ -50,6 +51,24 @@
5051
service_name text
5152
);
5253
"""
54+
SQL_EXAMPLES = [
55+
{
56+
'request': 'show me records where foobar is false',
57+
'response': "SELECT * FROM records WHERE attributes->>'foobar' = false",
58+
},
59+
{
60+
'request': 'show me records where attributes include the key "foobar"',
61+
'response': "SELECT * FROM records WHERE attributes ? 'foobar'",
62+
},
63+
{
64+
'request': 'show me records from yesterday',
65+
'response': "SELECT * FROM records WHERE start_timestamp::date > CURRENT_TIMESTAMP - INTERVAL '1 day'",
66+
},
67+
{
68+
'request': 'show me error records with the tag "foobar"',
69+
'response': "SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags)",
70+
},
71+
]
5372

5473

5574
@dataclass
@@ -93,18 +112,7 @@ async def system_prompt() -> str:
93112
94113
today's date = {date.today()}
95114
96-
Example
97-
request: show me records where foobar is false
98-
response: SELECT * FROM records WHERE attributes->>'foobar' = false
99-
Example
100-
request: show me records where attributes include the key "foobar"
101-
response: SELECT * FROM records WHERE attributes ? 'foobar'
102-
Example
103-
request: show me records from yesterday
104-
response: SELECT * FROM records WHERE start_timestamp::date > CURRENT_TIMESTAMP - INTERVAL '1 day'
105-
Example
106-
request: show me error records with the tag "foobar"
107-
response: SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags)
115+
{format_as_xml(SQL_EXAMPLES)}
108116
"""
109117

110118

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ nav:
4444
- api/exceptions.md
4545
- api/settings.md
4646
- api/usage.md
47+
- api/format_as_xml.md
4748
- api/models/base.md
4849
- api/models/openai.md
4950
- api/models/anthropic.md
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from __future__ import annotations as _annotations
2+
3+
from collections.abc import Iterable, Iterator, Mapping
4+
from dataclasses import asdict, dataclass, is_dataclass
5+
from datetime import date
6+
from typing import Any
7+
from xml.etree import ElementTree
8+
9+
from pydantic import BaseModel
10+
11+
__all__ = ('format_as_xml',)
12+
13+
14+
def format_as_xml(
15+
obj: Any,
16+
root_tag: str = 'examples',
17+
item_tag: str = 'example',
18+
include_root_tag: bool = True,
19+
none_str: str = 'null',
20+
indent: str | None = ' ',
21+
) -> str:
22+
"""Format a Python object as XML.
23+
24+
This is useful since LLMs often find it easier to read semi-structured data (e.g. examples) as XML,
25+
rather than JSON etc.
26+
27+
Supports: `str`, `bytes`, `bytearray`, `bool`, `int`, `float`, `date`, `datetime`, `Mapping`,
28+
`Iterable`, `dataclass`, and `BaseModel`.
29+
30+
Args:
31+
obj: Python Object to serialize to XML.
32+
root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag.
33+
item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name
34+
for dataclasses and Pydantic models.
35+
include_root_tag: Whether to include the root tag in the output
36+
(The root tag is always included if it includes a body - e.g. when the input is a simple value).
37+
none_str: String to use for `None` values.
38+
indent: Indentation string to use for pretty printing.
39+
40+
Returns: XML representation of the object.
41+
42+
Example:
43+
```python {title="format_as_xml_example.py" lint="skip"}
44+
from pydantic_ai.format_as_xml import format_as_xml
45+
46+
print(format_as_xml({'name': 'John', 'height': 6, 'weight': 200}, root_tag='user'))
47+
'''
48+
<user>
49+
<name>John</name>
50+
<height>6</height>
51+
<weight>200</weight>
52+
</user>
53+
'''
54+
```
55+
"""
56+
el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag)
57+
if not include_root_tag and el.text is None:
58+
join = '' if indent is None else '\n'
59+
return join.join(_rootless_xml_elements(el, indent))
60+
else:
61+
if indent is not None:
62+
ElementTree.indent(el, space=indent)
63+
return ElementTree.tostring(el, encoding='unicode')
64+
65+
66+
@dataclass
67+
class _ToXml:
68+
item_tag: str
69+
none_str: str
70+
71+
def to_xml(self, value: Any, tag: str | None) -> ElementTree.Element:
72+
element = ElementTree.Element(self.item_tag if tag is None else tag)
73+
if value is None:
74+
element.text = self.none_str
75+
elif isinstance(value, str):
76+
element.text = value
77+
elif isinstance(value, (bytes, bytearray)):
78+
element.text = value.decode(errors='ignore')
79+
elif isinstance(value, (bool, int, float)):
80+
element.text = str(value)
81+
elif isinstance(value, date):
82+
element.text = value.isoformat()
83+
elif isinstance(value, Mapping):
84+
self._mapping_to_xml(element, value) # pyright: ignore[reportUnknownArgumentType]
85+
elif is_dataclass(value) and not isinstance(value, type):
86+
if tag is None:
87+
element = ElementTree.Element(value.__class__.__name__)
88+
dc_dict = asdict(value)
89+
self._mapping_to_xml(element, dc_dict)
90+
elif isinstance(value, BaseModel):
91+
if tag is None:
92+
element = ElementTree.Element(value.__class__.__name__)
93+
self._mapping_to_xml(element, value.model_dump(mode='python'))
94+
elif isinstance(value, Iterable):
95+
for item in value: # pyright: ignore[reportUnknownVariableType]
96+
item_el = self.to_xml(item, None)
97+
element.append(item_el)
98+
else:
99+
raise TypeError(f'Unsupported type for XML formatting: {type(value)}')
100+
return element
101+
102+
def _mapping_to_xml(self, element: ElementTree.Element, mapping: Mapping[Any, Any]) -> None:
103+
for key, value in mapping.items():
104+
if isinstance(key, int):
105+
key = str(key)
106+
elif not isinstance(key, str):
107+
raise TypeError(f'Unsupported key type for XML formatting: {type(key)}, only str and int are allowed')
108+
element.append(self.to_xml(value, key))
109+
110+
111+
def _rootless_xml_elements(root: ElementTree.Element, indent: str | None) -> Iterator[str]:
112+
for sub_element in root:
113+
if indent is not None:
114+
ElementTree.indent(sub_element, space=indent)
115+
yield ElementTree.tostring(sub_element, encoding='unicode')

tests/test_format_as_xml.py

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
from dataclasses import dataclass
2+
from datetime import date, datetime
3+
from typing import Any
4+
5+
import pytest
6+
from inline_snapshot import snapshot
7+
from pydantic import BaseModel
8+
9+
from pydantic_ai.format_as_xml import format_as_xml
10+
11+
12+
@dataclass
13+
class ExampleDataclass:
14+
name: str
15+
age: int
16+
17+
18+
class ExamplePydanticModel(BaseModel):
19+
name: str
20+
age: int
21+
22+
23+
@pytest.mark.parametrize(
24+
'input_obj,output',
25+
[
26+
pytest.param('a string', snapshot('<examples>a string</examples>'), id='string'),
27+
pytest.param(42, snapshot('<examples>42</examples>'), id='int'),
28+
pytest.param(None, snapshot('<examples>null</examples>'), id='null'),
29+
pytest.param(
30+
ExampleDataclass(name='John', age=42),
31+
snapshot("""\
32+
<examples>
33+
<name>John</name>
34+
<age>42</age>
35+
</examples>\
36+
"""),
37+
id='dataclass',
38+
),
39+
pytest.param(
40+
ExamplePydanticModel(name='John', age=42),
41+
snapshot("""\
42+
<examples>
43+
<name>John</name>
44+
<age>42</age>
45+
</examples>\
46+
"""),
47+
id='pydantic model',
48+
),
49+
pytest.param(
50+
[ExampleDataclass(name='John', age=42)],
51+
snapshot("""\
52+
<examples>
53+
<ExampleDataclass>
54+
<name>John</name>
55+
<age>42</age>
56+
</ExampleDataclass>
57+
</examples>\
58+
"""),
59+
id='list[dataclass]',
60+
),
61+
pytest.param(
62+
[ExamplePydanticModel(name='John', age=42)],
63+
snapshot("""\
64+
<examples>
65+
<ExamplePydanticModel>
66+
<name>John</name>
67+
<age>42</age>
68+
</ExamplePydanticModel>
69+
</examples>\
70+
"""),
71+
id='list[pydantic model]',
72+
),
73+
pytest.param(
74+
[1, 2, 3],
75+
snapshot("""\
76+
<examples>
77+
<example>1</example>
78+
<example>2</example>
79+
<example>3</example>
80+
</examples>\
81+
"""),
82+
id='list[int]',
83+
),
84+
pytest.param(
85+
(1, 'x'),
86+
snapshot("""\
87+
<examples>
88+
<example>1</example>
89+
<example>x</example>
90+
</examples>\
91+
"""),
92+
id='tuple[int,str]',
93+
),
94+
pytest.param(
95+
[[1, 2], [3]],
96+
snapshot("""\
97+
<examples>
98+
<example>
99+
<example>1</example>
100+
<example>2</example>
101+
</example>
102+
<example>
103+
<example>3</example>
104+
</example>
105+
</examples>\
106+
"""),
107+
id='list[list[int]]',
108+
),
109+
pytest.param(
110+
{'x': 1, 'y': 3, 3: 'z', 4: {'a': -1, 'b': -2}},
111+
snapshot("""\
112+
<examples>
113+
<x>1</x>
114+
<y>3</y>
115+
<3>z</3>
116+
<4>
117+
<a>-1</a>
118+
<b>-2</b>
119+
</4>
120+
</examples>\
121+
"""),
122+
id='dict',
123+
),
124+
],
125+
)
126+
def test(input_obj: Any, output: str):
127+
assert format_as_xml(input_obj) == output
128+
129+
130+
@pytest.mark.parametrize(
131+
'input_obj,output',
132+
[
133+
pytest.param('a string', snapshot('<examples>a string</examples>'), id='string'),
134+
pytest.param('a <ex>foo</ex>', snapshot('<examples>a &lt;ex&gt;foo&lt;/ex&gt;</examples>'), id='string'),
135+
pytest.param(42, snapshot('<examples>42</examples>'), id='int'),
136+
pytest.param(
137+
[1, 2, 3],
138+
snapshot("""\
139+
<example>1</example>
140+
<example>2</example>
141+
<example>3</example>\
142+
"""),
143+
id='list[int]',
144+
),
145+
pytest.param(
146+
[[1, 2], [3]],
147+
snapshot("""\
148+
<example>
149+
<example>1</example>
150+
<example>2</example>
151+
</example>
152+
<example>
153+
<example>3</example>
154+
</example>\
155+
"""),
156+
id='list[list[int]]',
157+
),
158+
pytest.param(
159+
{'binary': b'my bytes', 'barray': bytearray(b'foo')},
160+
snapshot("""\
161+
<binary>my bytes</binary>
162+
<barray>foo</barray>\
163+
"""),
164+
id='dict[str, bytes]',
165+
),
166+
pytest.param(
167+
[datetime(2025, 1, 1, 12, 13), date(2025, 1, 2)],
168+
snapshot("""\
169+
<example>2025-01-01T12:13:00</example>
170+
<example>2025-01-02</example>\
171+
"""),
172+
id='list[date]',
173+
),
174+
],
175+
)
176+
def test_no_root(input_obj: Any, output: str):
177+
assert format_as_xml(input_obj, include_root_tag=False) == output
178+
179+
180+
def test_no_indent():
181+
assert format_as_xml([1, 2, 3], indent=None) == snapshot(
182+
'<examples><example>1</example><example>2</example><example>3</example></examples>'
183+
)
184+
assert format_as_xml([1, 2, 3], indent=None, include_root_tag=False) == snapshot(
185+
'<example>1</example><example>2</example><example>3</example>'
186+
)
187+
188+
189+
def test_invalid_value():
190+
with pytest.raises(TypeError, match='Unsupported type'):
191+
format_as_xml(object())
192+
193+
194+
def test_invalid_key():
195+
with pytest.raises(TypeError, match='Unsupported key type for XML formatting'):
196+
format_as_xml({(1, 2): 42})
197+
198+
199+
def test_set():
200+
assert '<example>1</example>' in format_as_xml({1, 2, 3})
201+
202+
203+
def test_custom_null():
204+
assert format_as_xml(None, none_str='nil') == snapshot('<examples>nil</examples>')

0 commit comments

Comments
 (0)