Skip to content

Commit a3cb8a2

Browse files
authored
feat: allow users to register custom encoders (#296)
1 parent 5f0e954 commit a3cb8a2

File tree

4 files changed

+67
-1
lines changed

4 files changed

+67
-1
lines changed

tests/test_items.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,3 +946,21 @@ def test_copy_copy():
946946
)
947947
def test_escape_key(key_str, escaped):
948948
assert api.key(key_str).as_string() == escaped
949+
950+
951+
def test_custom_encoders():
952+
import decimal
953+
954+
@api.register_encoder
955+
def encode_decimal(obj):
956+
if isinstance(obj, decimal.Decimal):
957+
return api.float_(str(obj))
958+
raise TypeError
959+
960+
assert api.item(decimal.Decimal("1.23")).as_string() == "1.23"
961+
962+
with pytest.raises(TypeError):
963+
api.item(object())
964+
965+
assert api.dumps({"foo": decimal.Decimal("1.23")}) == "foo = 1.23\n"
966+
api.unregister_encoder(encode_decimal)

tomlkit/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
from tomlkit.api import loads
1919
from tomlkit.api import nl
2020
from tomlkit.api import parse
21+
from tomlkit.api import register_encoder
2122
from tomlkit.api import string
2223
from tomlkit.api import table
2324
from tomlkit.api import time
25+
from tomlkit.api import unregister_encoder
2426
from tomlkit.api import value
2527
from tomlkit.api import ws
2628

@@ -52,4 +54,6 @@
5254
"TOMLDocument",
5355
"value",
5456
"ws",
57+
"register_encoder",
58+
"unregister_encoder",
5559
]

tomlkit/api.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import datetime as _datetime
45

56
from collections.abc import Mapping
67
from typing import IO
78
from typing import Iterable
9+
from typing import TypeVar
810

911
from tomlkit._utils import parse_rfc3339
1012
from tomlkit.container import Container
1113
from tomlkit.exceptions import UnexpectedCharError
14+
from tomlkit.items import CUSTOM_ENCODERS
1215
from tomlkit.items import AoT
1316
from tomlkit.items import Array
1417
from tomlkit.items import Bool
1518
from tomlkit.items import Comment
1619
from tomlkit.items import Date
1720
from tomlkit.items import DateTime
1821
from tomlkit.items import DottedKey
22+
from tomlkit.items import Encoder
1923
from tomlkit.items import Float
2024
from tomlkit.items import InlineTable
2125
from tomlkit.items import Integer
@@ -284,3 +288,21 @@ def nl() -> Whitespace:
284288
def comment(string: str) -> Comment:
285289
"""Create a comment item."""
286290
return Comment(Trivia(comment_ws=" ", comment="# " + string))
291+
292+
293+
E = TypeVar("E", bound=Encoder)
294+
295+
296+
def register_encoder(encoder: E) -> E:
297+
"""Add a custom encoder, which should be a function that will be called
298+
if the value can't otherwise be converted. It should takes a single value
299+
and return a TOMLKit item or raise a ``TypeError``.
300+
"""
301+
CUSTOM_ENCODERS.append(encoder)
302+
return encoder
303+
304+
305+
def unregister_encoder(encoder: Encoder) -> None:
306+
"""Unregister a custom encoder."""
307+
with contextlib.suppress(ValueError):
308+
CUSTOM_ENCODERS.remove(encoder)

tomlkit/items.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from enum import Enum
1414
from typing import TYPE_CHECKING
1515
from typing import Any
16+
from typing import Callable
1617
from typing import Collection
1718
from typing import Iterable
1819
from typing import Iterator
@@ -57,6 +58,15 @@ class _CustomDict(MutableMapping, dict):
5758

5859

5960
ItemT = TypeVar("ItemT", bound="Item")
61+
Encoder = Callable[[Any], "Item"]
62+
CUSTOM_ENCODERS: list[Encoder] = []
63+
64+
65+
class _ConvertError(TypeError, ValueError):
66+
"""An internal error raised when item() fails to convert a value.
67+
It should be a TypeError, but due to historical reasons
68+
it needs to subclass ValueError as well.
69+
"""
6070

6171

6272
@overload
@@ -218,8 +228,20 @@ def item(value: Any, _parent: Item | None = None, _sort_keys: bool = False) -> I
218228
Trivia(),
219229
value.isoformat(),
220230
)
231+
else:
232+
for encoder in CUSTOM_ENCODERS:
233+
try:
234+
rv = encoder(value)
235+
except TypeError:
236+
pass
237+
else:
238+
if not isinstance(rv, Item):
239+
raise _ConvertError(
240+
f"Custom encoder returned {type(rv)}, not a subclass of Item"
241+
)
242+
return rv
221243

222-
raise ValueError(f"Invalid type {type(value)}")
244+
raise _ConvertError(f"Invalid type {type(value)}")
223245

224246

225247
class StringType(Enum):

0 commit comments

Comments
 (0)