Skip to content

Commit a7f501c

Browse files
committed
feat: deprecate union aliasing on Python 3.14 and later
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent bc86b5e commit a7f501c

File tree

1 file changed

+171
-121
lines changed

1 file changed

+171
-121
lines changed

plum/alias.py

Lines changed: 171 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -26,135 +26,185 @@
2626
parsing how unions print.
2727
"""
2828

29+
import sys
2930
from functools import wraps
3031
from typing import TypeVar, Union, _type_repr, get_args
3132
from typing_extensions import assert_never
3233

34+
from typing_extensions import deprecated
35+
3336
__all__ = ["activate_union_aliases", "deactivate_union_aliases", "set_union_alias"]
3437

3538
UnionT = TypeVar("UnionT")
3639

37-
_union_type = type(Union[int, float]) # noqa: UP007
38-
_original_repr = _union_type.__repr__
39-
_original_str = _union_type.__str__
40-
41-
42-
@wraps(_original_repr)
43-
def _new_repr(self: object) -> str:
44-
"""Print a `typing.Union`, replacing all aliased unions by their aliased names.
45-
46-
Returns:
47-
str: Representation of a `typing.Union` taking into account union aliases.
48-
"""
49-
args = get_args(self)
50-
args_set = set(args)
51-
52-
# Find all aliased unions contained in this union.
53-
found_unions = []
54-
found_positions = []
55-
found_aliases = []
56-
for union, alias in reversed(_ALIASED_UNIONS):
57-
union_set = set(union)
58-
if union_set <= args_set:
59-
for i, arg in enumerate(args):
60-
if arg in union_set:
61-
found_unions.append(union_set)
62-
found_positions.append(i)
63-
found_aliases.append(alias)
40+
41+
if sys.version_info < (3, 14):
42+
_union_type = type(Union[int, float])
43+
_original_repr = _union_type.__repr__
44+
_original_str = _union_type.__str__
45+
46+
@wraps(_original_repr)
47+
def _new_repr(self: object) -> str:
48+
"""Print a `typing.Union`, replacing all aliased unions by their aliased names.
49+
50+
Returns:
51+
str: Representation of a `typing.Union` taking into account union aliases.
52+
"""
53+
args = get_args(self)
54+
args_set = set(args)
55+
56+
# Find all aliased unions contained in this union.
57+
found_unions = []
58+
found_positions = []
59+
found_aliases = []
60+
for union, alias in reversed(_ALIASED_UNIONS):
61+
union_set = set(union)
62+
if union_set <= args_set:
63+
found = False
64+
for i, arg in enumerate(args):
65+
if arg in union_set:
66+
found_unions.append(union_set)
67+
found_positions.append(i)
68+
found_aliases.append(alias)
69+
found = True
70+
break
71+
if not found: # pragma: no cover
72+
# This branch should never be reached.
73+
raise AssertionError(
74+
"Could not identify union. This should never happen."
75+
)
76+
77+
# Delete any unions that are contained in strictly bigger unions. We check for
78+
# strictly inequality because any union includes itself.
79+
for i in range(len(found_unions) - 1, -1, -1):
80+
for union in found_unions:
81+
if found_unions[i] < union:
82+
del found_unions[i]
83+
del found_positions[i]
84+
del found_aliases[i]
6485
break
65-
else: # pragma: no cover
66-
assert_never(union)
6786

68-
# Delete any unions that are contained in strictly bigger unions. We check
69-
# for strictly inequality because any union includes itself.
70-
for i in range(len(found_unions) - 1, -1, -1):
87+
# Create a set with all arguments of all found unions.
88+
found_args = set()
7189
for union in found_unions:
72-
if found_unions[i] < union:
73-
del found_unions[i]
74-
del found_positions[i]
75-
del found_aliases[i]
76-
break
77-
78-
# Create a set with all arguments of all found unions.
79-
found_args = set()
80-
for union in found_unions:
81-
found_args |= union
82-
83-
# Insert the aliases right before the first found argument. When we insert
84-
# an element, the positions of following insertions need to be appropriately
85-
# incremented.
86-
args = list(args)
87-
# Sort by insertion position to ensure that all following insertions are at
88-
# higher indices. This makes the bookkeeping simple.
89-
for delta, (i, alias) in enumerate(
90-
sorted(zip(found_positions, found_aliases, strict=True), key=lambda x: x[0])
91-
):
92-
args.insert(i + delta, alias)
93-
94-
# Filter all elements of unions that are aliased.
95-
new_args = ()
96-
for arg in args:
97-
if arg not in found_args:
98-
new_args += (arg,)
99-
args = new_args
100-
101-
# Generate a string representation.
102-
args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args]
103-
# Like `typing` does, print `Optional` whenever possible.
104-
if len(args) == 2:
105-
if args[0] is type(None): # noqa: E721
106-
return f"typing.Optional[{args_repr[1]}]"
107-
elif args[1] is type(None): # noqa: E721
108-
return f"typing.Optional[{args_repr[0]}]"
109-
# We would like to just print `args_repr[0]` whenever `len(args) == 1`, but
110-
# this might break code that parses how unions print.
111-
return "typing.Union[" + ", ".join(args_repr) + "]"
112-
113-
114-
@wraps(_original_str)
115-
def _new_str(self: object) -> str:
116-
"""Does the same as :func:`_new_repr`.
117-
118-
Returns:
119-
str: Representation of the `typing.Union` taking into account union aliases.
120-
"""
121-
return _new_repr(self)
122-
123-
124-
def activate_union_aliases() -> None:
125-
"""When printing `typing.Union`s, replace all aliased unions by the aliased names.
126-
This monkey patches `__repr__` and `__str__` for `typing.Union`."""
127-
_union_type.__repr__ = _new_repr
128-
_union_type.__str__ = _new_str
129-
130-
131-
def deactivate_union_aliases() -> None:
132-
"""Undo what :func:`.alias.activate` did. This restores the original `__repr__`
133-
and `__str__` for `typing.Union`."""
134-
_union_type.__repr__ = _original_repr
135-
_union_type.__str__ = _original_str
136-
137-
138-
_ALIASED_UNIONS: list = []
139-
140-
141-
def set_union_alias(union: UnionT, alias: str) -> UnionT:
142-
"""Change how a `typing.Union` is printed. This does not modify `union`.
143-
144-
Args:
145-
union (type or type hint): A union.
146-
alias (str): How to print `union`.
147-
148-
Returns:
149-
type or type hint: `union`.
150-
"""
151-
args = get_args(union) if isinstance(union, _union_type) else (union,)
152-
for existing_union, existing_alias in _ALIASED_UNIONS:
153-
if set(existing_union) == set(args) and alias != existing_alias:
154-
if isinstance(union, _union_type):
155-
union_str = _original_str(union)
156-
else:
157-
union_str = repr(union)
158-
raise RuntimeError(f"`{union_str}` already has alias `{existing_alias}`.")
159-
_ALIASED_UNIONS.append((args, alias))
160-
return union
90+
found_args |= union
91+
92+
# Insert the aliases right before the first found argument. When we insert an
93+
# element, the positions of following insertions need to be appropriately
94+
# incremented.
95+
args = list(args)
96+
# Sort by insertion position to ensure that all following insertions are
97+
# at higher indices. This makes the bookkeeping simple.
98+
for delta, (i, alias) in enumerate(
99+
sorted(zip(found_positions, found_aliases), key=lambda x: x[0])
100+
):
101+
args.insert(i + delta, alias)
102+
103+
# Filter all elements of unions that are aliased.
104+
new_args = ()
105+
for arg in args:
106+
if arg not in found_args:
107+
new_args += (arg,)
108+
args = new_args
109+
110+
# Generate a string representation.
111+
args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args]
112+
# Like `typing` does, print `Optional` whenever possible.
113+
if len(args) == 2:
114+
if args[0] is type(None): # noqa: E721
115+
return f"typing.Optional[{args_repr[1]}]"
116+
elif args[1] is type(None): # noqa: E721
117+
return f"typing.Optional[{args_repr[0]}]"
118+
# We would like to just print `args_repr[0]` whenever `len(args) == 1`, but
119+
# this might break code that parses how unions print.
120+
return "typing.Union[" + ", ".join(args_repr) + "]"
121+
122+
@wraps(_original_str)
123+
def _new_str(self: object) -> str:
124+
"""Does the same as :func:`_new_repr`.
125+
126+
Returns:
127+
str: Representation of the `typing.Union` taking into account union aliases.
128+
"""
129+
return _new_repr(self)
130+
131+
def activate_union_aliases() -> None:
132+
"""When printing `typing.Union`s, replace aliased unions by the aliased names.
133+
This monkey patches `__repr__` and `__str__` for `typing.Union`."""
134+
_union_type.__repr__ = _new_repr
135+
_union_type.__str__ = _new_str
136+
137+
def deactivate_union_aliases() -> None:
138+
"""Undo what :func:`.alias.activate` did. This restores the original `__repr__`
139+
and `__str__` for `typing.Union`."""
140+
_union_type.__repr__ = _original_repr
141+
_union_type.__str__ = _original_str
142+
143+
_ALIASED_UNIONS: list = []
144+
145+
def set_union_alias(union: UnionT, alias: str) -> UnionT:
146+
"""Change how a `typing.Union` is printed. This does not modify `union`.
147+
148+
Args:
149+
union (type or type hint): A union.
150+
alias (str): How to print `union`.
151+
152+
Returns:
153+
type or type hint: `union`.
154+
"""
155+
if sys.version_info >= (3, 14):
156+
return union
157+
158+
args = get_args(union) if isinstance(union, _union_type) else (union,)
159+
for existing_union, existing_alias in _ALIASED_UNIONS:
160+
if set(existing_union) == set(args) and alias != existing_alias:
161+
if isinstance(union, _union_type):
162+
union_str = _original_str(union)
163+
else:
164+
union_str = repr(union)
165+
raise RuntimeError(
166+
f"`{union_str}` already has alias `{existing_alias}`."
167+
)
168+
_ALIASED_UNIONS.append((args, alias))
169+
return union
170+
171+
172+
else:
173+
174+
@deprecated(
175+
"Plum's union aliasing is not supported on Python 3.14 and later.",
176+
category=RuntimeWarning,
177+
stacklevel=2,
178+
)
179+
def activate_union_aliases() -> None:
180+
"""When printing `typing.Union`s, replace aliased unions by the aliased names.
181+
This monkey patches `__repr__` and `__str__` for `typing.Union`."""
182+
183+
@deprecated(
184+
"Plum's union aliasing is not supported on Python 3.14 and later.",
185+
category=RuntimeWarning,
186+
stacklevel=2,
187+
)
188+
def deactivate_union_aliases() -> None:
189+
"""Undo what :func:`.alias.activate` did. This restores the original `__repr__`
190+
and `__str__` for `typing.Union`."""
191+
if sys.version_info < (3, 14):
192+
_union_type.__repr__ = _original_repr
193+
_union_type.__str__ = _original_str
194+
195+
@deprecated(
196+
"Plum's union aliasing is not supported on Python 3.14 and later.",
197+
category=RuntimeWarning,
198+
stacklevel=2,
199+
)
200+
def set_union_alias(union: UnionT, alias: str) -> UnionT:
201+
"""Change how a `typing.Union` is printed. This does not modify `union`.
202+
203+
Args:
204+
union (type or type hint): A union.
205+
alias (str): How to print `union`.
206+
207+
Returns:
208+
type or type hint: `union`.
209+
"""
210+
return union

0 commit comments

Comments
 (0)