Skip to content

Commit e07f6a1

Browse files
authored
fix: full methods of integer and float (#307)
1 parent 9e39a63 commit e07f6a1

File tree

5 files changed

+160
-96
lines changed

5 files changed

+160
-96
lines changed

tomlkit/_types.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
from typing import Any
5+
from typing import TypeVar
6+
7+
8+
WT = TypeVar("WT", bound="WrapperType")
9+
10+
if TYPE_CHECKING: # pragma: no cover
11+
# Define _CustomList and _CustomDict as a workaround for:
12+
# https://github.com/python/mypy/issues/11427
13+
#
14+
# According to this issue, the typeshed contains a "lie"
15+
# (it adds MutableSequence to the ancestry of list and MutableMapping to
16+
# the ancestry of dict) which completely messes with the type inference for
17+
# Table, InlineTable, Array and Container.
18+
#
19+
# Importing from builtins is preferred over simple assignment, see issues:
20+
# https://github.com/python/mypy/issues/8715
21+
# https://github.com/python/mypy/issues/10068
22+
from builtins import dict as _CustomDict # noqa: N812
23+
from builtins import float as _CustomFloat # noqa: N812
24+
from builtins import int as _CustomInt # noqa: N812
25+
from builtins import list as _CustomList # noqa: N812
26+
from typing import Callable
27+
from typing import Concatenate
28+
from typing import ParamSpec
29+
from typing import Protocol
30+
31+
P = ParamSpec("P")
32+
33+
class WrapperType(Protocol):
34+
def _new(self: WT, value: Any) -> WT:
35+
...
36+
37+
else:
38+
from collections.abc import MutableMapping
39+
from collections.abc import MutableSequence
40+
from numbers import Integral
41+
from numbers import Real
42+
43+
class _CustomList(MutableSequence, list):
44+
"""Adds MutableSequence mixin while pretending to be a builtin list"""
45+
46+
class _CustomDict(MutableMapping, dict):
47+
"""Adds MutableMapping mixin while pretending to be a builtin dict"""
48+
49+
class _CustomInt(Integral, int):
50+
"""Adds Integral mixin while pretending to be a builtin int"""
51+
52+
class _CustomFloat(Real, float):
53+
"""Adds Real mixin while pretending to be a builtin float"""
54+
55+
56+
def wrap_method(
57+
original_method: Callable[Concatenate[WT, P], Any]
58+
) -> Callable[Concatenate[WT, P], Any]:
59+
def wrapper(self: WT, *args: P.args, **kwargs: P.kwargs) -> Any:
60+
result = original_method(self, *args, **kwargs)
61+
if result is NotImplemented:
62+
return result
63+
return self._new(result)
64+
65+
return wrapper

tomlkit/container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Iterator
77

88
from tomlkit._compat import decode
9+
from tomlkit._types import _CustomDict
910
from tomlkit._utils import merge_dicts
1011
from tomlkit.exceptions import KeyAlreadyPresent
1112
from tomlkit.exceptions import NonExistentKey
@@ -19,7 +20,6 @@
1920
from tomlkit.items import Table
2021
from tomlkit.items import Trivia
2122
from tomlkit.items import Whitespace
22-
from tomlkit.items import _CustomDict
2323
from tomlkit.items import item as _item
2424

2525

tomlkit/items.py

Lines changed: 92 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import abc
44
import copy
55
import dataclasses
6+
import math
67
import re
78
import string
9+
import sys
810

911
from datetime import date
1012
from datetime import datetime
@@ -24,42 +26,24 @@
2426

2527
from tomlkit._compat import PY38
2628
from tomlkit._compat import decode
29+
from tomlkit._types import _CustomDict
30+
from tomlkit._types import _CustomFloat
31+
from tomlkit._types import _CustomInt
32+
from tomlkit._types import _CustomList
33+
from tomlkit._types import wrap_method
2734
from tomlkit._utils import CONTROL_CHARS
2835
from tomlkit._utils import escape_string
2936
from tomlkit.exceptions import InvalidStringError
3037

3138

32-
if TYPE_CHECKING: # pragma: no cover
33-
# Define _CustomList and _CustomDict as a workaround for:
34-
# https://github.com/python/mypy/issues/11427
35-
#
36-
# According to this issue, the typeshed contains a "lie"
37-
# (it adds MutableSequence to the ancestry of list and MutableMapping to
38-
# the ancestry of dict) which completely messes with the type inference for
39-
# Table, InlineTable, Array and Container.
40-
#
41-
# Importing from builtins is preferred over simple assignment, see issues:
42-
# https://github.com/python/mypy/issues/8715
43-
# https://github.com/python/mypy/issues/10068
44-
from builtins import dict as _CustomDict # noqa: N812, TC004
45-
from builtins import list as _CustomList # noqa: N812, TC004
46-
47-
# Allow type annotations but break circular imports
39+
if TYPE_CHECKING:
4840
from tomlkit import container
49-
else:
50-
from collections.abc import MutableMapping
51-
from collections.abc import MutableSequence
52-
53-
class _CustomList(MutableSequence, list):
54-
"""Adds MutableSequence mixin while pretending to be a builtin list"""
55-
56-
class _CustomDict(MutableMapping, dict):
57-
"""Adds MutableMapping mixin while pretending to be a builtin dict"""
5841

5942

6043
ItemT = TypeVar("ItemT", bound="Item")
6144
Encoder = Callable[[Any], "Item"]
6245
CUSTOM_ENCODERS: list[Encoder] = []
46+
AT = TypeVar("AT", bound="AbstractTable")
6347

6448

6549
class _ConvertError(TypeError, ValueError):
@@ -456,7 +440,7 @@ def __eq__(self, other: Any) -> bool:
456440
class DottedKey(Key):
457441
def __init__(
458442
self,
459-
keys: Iterable[Key],
443+
keys: Iterable[SingleKey],
460444
sep: str | None = None,
461445
original: str | None = None,
462446
) -> None:
@@ -606,25 +590,27 @@ def __str__(self) -> str:
606590
return f"{self._trivia.indent}{decode(self._trivia.comment)}"
607591

608592

609-
class Integer(int, Item):
593+
class Integer(Item, _CustomInt):
610594
"""
611595
An integer literal.
612596
"""
613597

614598
def __new__(cls, value: int, trivia: Trivia, raw: str) -> Integer:
615-
return super().__new__(cls, value)
599+
return int.__new__(cls, value)
616600

617-
def __init__(self, _: int, trivia: Trivia, raw: str) -> None:
601+
def __init__(self, value: int, trivia: Trivia, raw: str) -> None:
618602
super().__init__(trivia)
619-
603+
self._original = value
620604
self._raw = raw
621605
self._sign = False
622606

623607
if re.match(r"^[+\-]\d+$", raw):
624608
self._sign = True
625609

626610
def unwrap(self) -> int:
627-
return int(self)
611+
return self._original
612+
613+
__int__ = unwrap
628614

629615
@property
630616
def discriminant(self) -> int:
@@ -638,30 +624,6 @@ def value(self) -> int:
638624
def as_string(self) -> str:
639625
return self._raw
640626

641-
def __add__(self, other):
642-
result = super().__add__(other)
643-
if result is NotImplemented:
644-
return result
645-
return self._new(result)
646-
647-
def __radd__(self, other):
648-
result = super().__radd__(other)
649-
if result is NotImplemented:
650-
return result
651-
return self._new(result)
652-
653-
def __sub__(self, other):
654-
result = super().__sub__(other)
655-
if result is NotImplemented:
656-
return result
657-
return self._new(result)
658-
659-
def __rsub__(self, other):
660-
result = super().__rsub__(other)
661-
if result is NotImplemented:
662-
return result
663-
return self._new(result)
664-
665627
def _new(self, result):
666628
raw = str(result)
667629
if self._sign:
@@ -673,26 +635,63 @@ def _new(self, result):
673635
def _getstate(self, protocol=3):
674636
return int(self), self._trivia, self._raw
675637

676-
677-
class Float(float, Item):
638+
# int methods
639+
__abs__ = wrap_method(int.__abs__)
640+
__add__ = wrap_method(int.__add__)
641+
__and__ = wrap_method(int.__and__)
642+
__ceil__ = wrap_method(int.__ceil__)
643+
__eq__ = int.__eq__
644+
__floor__ = wrap_method(int.__floor__)
645+
__floordiv__ = wrap_method(int.__floordiv__)
646+
__invert__ = wrap_method(int.__invert__)
647+
__le__ = int.__le__
648+
__lshift__ = wrap_method(int.__lshift__)
649+
__lt__ = int.__lt__
650+
__mod__ = wrap_method(int.__mod__)
651+
__mul__ = wrap_method(int.__mul__)
652+
__neg__ = wrap_method(int.__neg__)
653+
__or__ = wrap_method(int.__or__)
654+
__pos__ = wrap_method(int.__pos__)
655+
__pow__ = wrap_method(int.__pow__)
656+
__radd__ = wrap_method(int.__radd__)
657+
__rand__ = wrap_method(int.__rand__)
658+
__rfloordiv__ = wrap_method(int.__rfloordiv__)
659+
__rlshift__ = wrap_method(int.__rlshift__)
660+
__rmod__ = wrap_method(int.__rmod__)
661+
__rmul__ = wrap_method(int.__rmul__)
662+
__ror__ = wrap_method(int.__ror__)
663+
__round__ = wrap_method(int.__round__)
664+
__rpow__ = wrap_method(int.__rpow__)
665+
__rrshift__ = wrap_method(int.__rrshift__)
666+
__rshift__ = wrap_method(int.__rshift__)
667+
__rtruediv__ = wrap_method(int.__rtruediv__)
668+
__rxor__ = wrap_method(int.__rxor__)
669+
__truediv__ = wrap_method(int.__truediv__)
670+
__trunc__ = wrap_method(int.__trunc__)
671+
__xor__ = wrap_method(int.__xor__)
672+
673+
674+
class Float(Item, _CustomFloat):
678675
"""
679676
A float literal.
680677
"""
681678

682-
def __new__(cls, value: float, trivia: Trivia, raw: str) -> Integer:
683-
return super().__new__(cls, value)
679+
def __new__(cls, value: float, trivia: Trivia, raw: str) -> Float:
680+
return float.__new__(cls, value)
684681

685-
def __init__(self, _: float, trivia: Trivia, raw: str) -> None:
682+
def __init__(self, value: float, trivia: Trivia, raw: str) -> None:
686683
super().__init__(trivia)
687-
684+
self._original = value
688685
self._raw = raw
689686
self._sign = False
690687

691688
if re.match(r"^[+\-].+$", raw):
692689
self._sign = True
693690

694691
def unwrap(self) -> float:
695-
return float(self)
692+
return self._original
693+
694+
__float__ = unwrap
696695

697696
@property
698697
def discriminant(self) -> int:
@@ -706,32 +705,6 @@ def value(self) -> float:
706705
def as_string(self) -> str:
707706
return self._raw
708707

709-
def __add__(self, other):
710-
result = super().__add__(other)
711-
712-
return self._new(result)
713-
714-
def __radd__(self, other):
715-
result = super().__radd__(other)
716-
717-
if isinstance(other, Float):
718-
return self._new(result)
719-
720-
return result
721-
722-
def __sub__(self, other):
723-
result = super().__sub__(other)
724-
725-
return self._new(result)
726-
727-
def __rsub__(self, other):
728-
result = super().__rsub__(other)
729-
730-
if isinstance(other, Float):
731-
return self._new(result)
732-
733-
return result
734-
735708
def _new(self, result):
736709
raw = str(result)
737710

@@ -744,6 +717,35 @@ def _new(self, result):
744717
def _getstate(self, protocol=3):
745718
return float(self), self._trivia, self._raw
746719

720+
# float methods
721+
__abs__ = wrap_method(float.__abs__)
722+
__add__ = wrap_method(float.__add__)
723+
__eq__ = float.__eq__
724+
__floordiv__ = wrap_method(float.__floordiv__)
725+
__le__ = float.__le__
726+
__lt__ = float.__lt__
727+
__mod__ = wrap_method(float.__mod__)
728+
__mul__ = wrap_method(float.__mul__)
729+
__neg__ = wrap_method(float.__neg__)
730+
__pos__ = wrap_method(float.__pos__)
731+
__pow__ = wrap_method(float.__pow__)
732+
__radd__ = wrap_method(float.__radd__)
733+
__rfloordiv__ = wrap_method(float.__rfloordiv__)
734+
__rmod__ = wrap_method(float.__rmod__)
735+
__rmul__ = wrap_method(float.__rmul__)
736+
__round__ = wrap_method(float.__round__)
737+
__rpow__ = wrap_method(float.__rpow__)
738+
__rtruediv__ = wrap_method(float.__rtruediv__)
739+
__truediv__ = wrap_method(float.__truediv__)
740+
__trunc__ = float.__trunc__
741+
742+
if sys.version_info >= (3, 9):
743+
__ceil__ = float.__ceil__
744+
__floor__ = float.__floor__
745+
else:
746+
__ceil__ = math.ceil
747+
__floor__ = math.floor
748+
747749

748750
class Bool(Item):
749751
"""
@@ -1410,9 +1412,6 @@ def _getstate(self, protocol=3):
14101412
return list(self._iter_items()), self._trivia, self._multiline
14111413

14121414

1413-
AT = TypeVar("AT", bound="AbstractTable")
1414-
1415-
14161415
class AbstractTable(Item, _CustomDict):
14171416
"""Common behaviour of both :class:`Table` and :class:`InlineTable`"""
14181417

@@ -1452,11 +1451,11 @@ def append(self, key, value):
14521451
raise NotImplementedError
14531452

14541453
@overload
1455-
def add(self: AT, value: Comment | Whitespace) -> AT:
1454+
def add(self: AT, key: Comment | Whitespace) -> AT:
14561455
...
14571456

14581457
@overload
1459-
def add(self: AT, key: Key | str, value: Any) -> AT:
1458+
def add(self: AT, key: Key | str, value: Any = ...) -> AT:
14601459
...
14611460

14621461
def add(self, key, value=None):

tomlkit/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class Parser:
6060
Parser for TOML documents.
6161
"""
6262

63-
def __init__(self, string: str) -> None:
63+
def __init__(self, string: str | bytes) -> None:
6464
# Input to parse
6565
self._src = Source(decode(string))
6666

tomlkit/source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(self, source: Source) -> None:
5050
def __call__(self, *args, **kwargs):
5151
return _State(self._source, *args, **kwargs)
5252

53-
def __enter__(self) -> None:
53+
def __enter__(self) -> _State:
5454
state = self()
5555
self._states.append(state)
5656
return state.__enter__()

0 commit comments

Comments
 (0)