33import abc
44import copy
55import dataclasses
6+ import math
67import re
78import string
9+ import sys
810
911from datetime import date
1012from datetime import datetime
2426
2527from tomlkit ._compat import PY38
2628from tomlkit ._compat import decode
29+ from tomlkit ._types import _CustomDict
30+ from tomlkit ._types import _CustomFloat
31+ from tomlkit ._types import _CustomInt
32+ from tomlkit ._types import _CustomList
33+ from tomlkit ._types import wrap_method
2734from tomlkit ._utils import CONTROL_CHARS
2835from tomlkit ._utils import escape_string
2936from tomlkit .exceptions import InvalidStringError
3037
3138
32- if TYPE_CHECKING : # pragma: no cover
33- # Define _CustomList and _CustomDict as a workaround for:
34- # https://github.com/python/mypy/issues/11427
35- #
36- # According to this issue, the typeshed contains a "lie"
37- # (it adds MutableSequence to the ancestry of list and MutableMapping to
38- # the ancestry of dict) which completely messes with the type inference for
39- # Table, InlineTable, Array and Container.
40- #
41- # Importing from builtins is preferred over simple assignment, see issues:
42- # https://github.com/python/mypy/issues/8715
43- # https://github.com/python/mypy/issues/10068
44- from builtins import dict as _CustomDict # noqa: N812, TC004
45- from builtins import list as _CustomList # noqa: N812, TC004
46-
47- # Allow type annotations but break circular imports
39+ if TYPE_CHECKING :
4840 from tomlkit import container
49- else :
50- from collections .abc import MutableMapping
51- from collections .abc import MutableSequence
52-
53- class _CustomList (MutableSequence , list ):
54- """Adds MutableSequence mixin while pretending to be a builtin list"""
55-
56- class _CustomDict (MutableMapping , dict ):
57- """Adds MutableMapping mixin while pretending to be a builtin dict"""
5841
5942
6043ItemT = TypeVar ("ItemT" , bound = "Item" )
6144Encoder = Callable [[Any ], "Item" ]
6245CUSTOM_ENCODERS : list [Encoder ] = []
46+ AT = TypeVar ("AT" , bound = "AbstractTable" )
6347
6448
6549class _ConvertError (TypeError , ValueError ):
@@ -456,7 +440,7 @@ def __eq__(self, other: Any) -> bool:
456440class DottedKey (Key ):
457441 def __init__ (
458442 self ,
459- keys : Iterable [Key ],
443+ keys : Iterable [SingleKey ],
460444 sep : str | None = None ,
461445 original : str | None = None ,
462446 ) -> None :
@@ -606,25 +590,27 @@ def __str__(self) -> str:
606590 return f"{ self ._trivia .indent } { decode (self ._trivia .comment )} "
607591
608592
609- class Integer (int , Item ):
593+ class Integer (Item , _CustomInt ):
610594 """
611595 An integer literal.
612596 """
613597
614598 def __new__ (cls , value : int , trivia : Trivia , raw : str ) -> Integer :
615- return super () .__new__ (cls , value )
599+ return int .__new__ (cls , value )
616600
617- def __init__ (self , _ : int , trivia : Trivia , raw : str ) -> None :
601+ def __init__ (self , value : int , trivia : Trivia , raw : str ) -> None :
618602 super ().__init__ (trivia )
619-
603+ self . _original = value
620604 self ._raw = raw
621605 self ._sign = False
622606
623607 if re .match (r"^[+\-]\d+$" , raw ):
624608 self ._sign = True
625609
626610 def unwrap (self ) -> int :
627- return int (self )
611+ return self ._original
612+
613+ __int__ = unwrap
628614
629615 @property
630616 def discriminant (self ) -> int :
@@ -638,30 +624,6 @@ def value(self) -> int:
638624 def as_string (self ) -> str :
639625 return self ._raw
640626
641- def __add__ (self , other ):
642- result = super ().__add__ (other )
643- if result is NotImplemented :
644- return result
645- return self ._new (result )
646-
647- def __radd__ (self , other ):
648- result = super ().__radd__ (other )
649- if result is NotImplemented :
650- return result
651- return self ._new (result )
652-
653- def __sub__ (self , other ):
654- result = super ().__sub__ (other )
655- if result is NotImplemented :
656- return result
657- return self ._new (result )
658-
659- def __rsub__ (self , other ):
660- result = super ().__rsub__ (other )
661- if result is NotImplemented :
662- return result
663- return self ._new (result )
664-
665627 def _new (self , result ):
666628 raw = str (result )
667629 if self ._sign :
@@ -673,26 +635,63 @@ def _new(self, result):
673635 def _getstate (self , protocol = 3 ):
674636 return int (self ), self ._trivia , self ._raw
675637
676-
677- class Float (float , Item ):
638+ # int methods
639+ __abs__ = wrap_method (int .__abs__ )
640+ __add__ = wrap_method (int .__add__ )
641+ __and__ = wrap_method (int .__and__ )
642+ __ceil__ = wrap_method (int .__ceil__ )
643+ __eq__ = int .__eq__
644+ __floor__ = wrap_method (int .__floor__ )
645+ __floordiv__ = wrap_method (int .__floordiv__ )
646+ __invert__ = wrap_method (int .__invert__ )
647+ __le__ = int .__le__
648+ __lshift__ = wrap_method (int .__lshift__ )
649+ __lt__ = int .__lt__
650+ __mod__ = wrap_method (int .__mod__ )
651+ __mul__ = wrap_method (int .__mul__ )
652+ __neg__ = wrap_method (int .__neg__ )
653+ __or__ = wrap_method (int .__or__ )
654+ __pos__ = wrap_method (int .__pos__ )
655+ __pow__ = wrap_method (int .__pow__ )
656+ __radd__ = wrap_method (int .__radd__ )
657+ __rand__ = wrap_method (int .__rand__ )
658+ __rfloordiv__ = wrap_method (int .__rfloordiv__ )
659+ __rlshift__ = wrap_method (int .__rlshift__ )
660+ __rmod__ = wrap_method (int .__rmod__ )
661+ __rmul__ = wrap_method (int .__rmul__ )
662+ __ror__ = wrap_method (int .__ror__ )
663+ __round__ = wrap_method (int .__round__ )
664+ __rpow__ = wrap_method (int .__rpow__ )
665+ __rrshift__ = wrap_method (int .__rrshift__ )
666+ __rshift__ = wrap_method (int .__rshift__ )
667+ __rtruediv__ = wrap_method (int .__rtruediv__ )
668+ __rxor__ = wrap_method (int .__rxor__ )
669+ __truediv__ = wrap_method (int .__truediv__ )
670+ __trunc__ = wrap_method (int .__trunc__ )
671+ __xor__ = wrap_method (int .__xor__ )
672+
673+
674+ class Float (Item , _CustomFloat ):
678675 """
679676 A float literal.
680677 """
681678
682- def __new__ (cls , value : float , trivia : Trivia , raw : str ) -> Integer :
683- return super () .__new__ (cls , value )
679+ def __new__ (cls , value : float , trivia : Trivia , raw : str ) -> Float :
680+ return float .__new__ (cls , value )
684681
685- def __init__ (self , _ : float , trivia : Trivia , raw : str ) -> None :
682+ def __init__ (self , value : float , trivia : Trivia , raw : str ) -> None :
686683 super ().__init__ (trivia )
687-
684+ self . _original = value
688685 self ._raw = raw
689686 self ._sign = False
690687
691688 if re .match (r"^[+\-].+$" , raw ):
692689 self ._sign = True
693690
694691 def unwrap (self ) -> float :
695- return float (self )
692+ return self ._original
693+
694+ __float__ = unwrap
696695
697696 @property
698697 def discriminant (self ) -> int :
@@ -706,32 +705,6 @@ def value(self) -> float:
706705 def as_string (self ) -> str :
707706 return self ._raw
708707
709- def __add__ (self , other ):
710- result = super ().__add__ (other )
711-
712- return self ._new (result )
713-
714- def __radd__ (self , other ):
715- result = super ().__radd__ (other )
716-
717- if isinstance (other , Float ):
718- return self ._new (result )
719-
720- return result
721-
722- def __sub__ (self , other ):
723- result = super ().__sub__ (other )
724-
725- return self ._new (result )
726-
727- def __rsub__ (self , other ):
728- result = super ().__rsub__ (other )
729-
730- if isinstance (other , Float ):
731- return self ._new (result )
732-
733- return result
734-
735708 def _new (self , result ):
736709 raw = str (result )
737710
@@ -744,6 +717,35 @@ def _new(self, result):
744717 def _getstate (self , protocol = 3 ):
745718 return float (self ), self ._trivia , self ._raw
746719
720+ # float methods
721+ __abs__ = wrap_method (float .__abs__ )
722+ __add__ = wrap_method (float .__add__ )
723+ __eq__ = float .__eq__
724+ __floordiv__ = wrap_method (float .__floordiv__ )
725+ __le__ = float .__le__
726+ __lt__ = float .__lt__
727+ __mod__ = wrap_method (float .__mod__ )
728+ __mul__ = wrap_method (float .__mul__ )
729+ __neg__ = wrap_method (float .__neg__ )
730+ __pos__ = wrap_method (float .__pos__ )
731+ __pow__ = wrap_method (float .__pow__ )
732+ __radd__ = wrap_method (float .__radd__ )
733+ __rfloordiv__ = wrap_method (float .__rfloordiv__ )
734+ __rmod__ = wrap_method (float .__rmod__ )
735+ __rmul__ = wrap_method (float .__rmul__ )
736+ __round__ = wrap_method (float .__round__ )
737+ __rpow__ = wrap_method (float .__rpow__ )
738+ __rtruediv__ = wrap_method (float .__rtruediv__ )
739+ __truediv__ = wrap_method (float .__truediv__ )
740+ __trunc__ = float .__trunc__
741+
742+ if sys .version_info >= (3 , 9 ):
743+ __ceil__ = float .__ceil__
744+ __floor__ = float .__floor__
745+ else :
746+ __ceil__ = math .ceil
747+ __floor__ = math .floor
748+
747749
748750class Bool (Item ):
749751 """
@@ -1410,9 +1412,6 @@ def _getstate(self, protocol=3):
14101412 return list (self ._iter_items ()), self ._trivia , self ._multiline
14111413
14121414
1413- AT = TypeVar ("AT" , bound = "AbstractTable" )
1414-
1415-
14161415class AbstractTable (Item , _CustomDict ):
14171416 """Common behaviour of both :class:`Table` and :class:`InlineTable`"""
14181417
@@ -1452,11 +1451,11 @@ def append(self, key, value):
14521451 raise NotImplementedError
14531452
14541453 @overload
1455- def add (self : AT , value : Comment | Whitespace ) -> AT :
1454+ def add (self : AT , key : Comment | Whitespace ) -> AT :
14561455 ...
14571456
14581457 @overload
1459- def add (self : AT , key : Key | str , value : Any ) -> AT :
1458+ def add (self : AT , key : Key | str , value : Any = ... ) -> AT :
14601459 ...
14611460
14621461 def add (self , key , value = None ):
0 commit comments