|
26 | 26 | parsing how unions print. |
27 | 27 | """ |
28 | 28 |
|
| 29 | +import sys |
29 | 30 | from functools import wraps |
30 | 31 | from typing import TypeVar, Union, _type_repr, get_args |
31 | 32 | from typing_extensions import assert_never |
32 | 33 |
|
| 34 | +from typing_extensions import deprecated |
| 35 | + |
33 | 36 | __all__ = ["activate_union_aliases", "deactivate_union_aliases", "set_union_alias"] |
34 | 37 |
|
35 | 38 | UnionT = TypeVar("UnionT") |
36 | 39 |
|
37 | | -_union_type = type(Union[int, float]) # noqa: UP007 |
38 | | -_original_repr = _union_type.__repr__ |
39 | | -_original_str = _union_type.__str__ |
40 | | - |
41 | | - |
42 | | -@wraps(_original_repr) |
43 | | -def _new_repr(self: object) -> str: |
44 | | - """Print a `typing.Union`, replacing all aliased unions by their aliased names. |
45 | | -
|
46 | | - Returns: |
47 | | - str: Representation of a `typing.Union` taking into account union aliases. |
48 | | - """ |
49 | | - args = get_args(self) |
50 | | - args_set = set(args) |
51 | | - |
52 | | - # Find all aliased unions contained in this union. |
53 | | - found_unions = [] |
54 | | - found_positions = [] |
55 | | - found_aliases = [] |
56 | | - for union, alias in reversed(_ALIASED_UNIONS): |
57 | | - union_set = set(union) |
58 | | - if union_set <= args_set: |
59 | | - for i, arg in enumerate(args): |
60 | | - if arg in union_set: |
61 | | - found_unions.append(union_set) |
62 | | - found_positions.append(i) |
63 | | - found_aliases.append(alias) |
| 40 | + |
| 41 | +if sys.version_info < (3, 14): |
| 42 | + _union_type = type(Union[int, float]) |
| 43 | + _original_repr = _union_type.__repr__ |
| 44 | + _original_str = _union_type.__str__ |
| 45 | + |
| 46 | + @wraps(_original_repr) |
| 47 | + def _new_repr(self: object) -> str: |
| 48 | + """Print a `typing.Union`, replacing all aliased unions by their aliased names. |
| 49 | +
|
| 50 | + Returns: |
| 51 | + str: Representation of a `typing.Union` taking into account union aliases. |
| 52 | + """ |
| 53 | + args = get_args(self) |
| 54 | + args_set = set(args) |
| 55 | + |
| 56 | + # Find all aliased unions contained in this union. |
| 57 | + found_unions = [] |
| 58 | + found_positions = [] |
| 59 | + found_aliases = [] |
| 60 | + for union, alias in reversed(_ALIASED_UNIONS): |
| 61 | + union_set = set(union) |
| 62 | + if union_set <= args_set: |
| 63 | + found = False |
| 64 | + for i, arg in enumerate(args): |
| 65 | + if arg in union_set: |
| 66 | + found_unions.append(union_set) |
| 67 | + found_positions.append(i) |
| 68 | + found_aliases.append(alias) |
| 69 | + found = True |
| 70 | + break |
| 71 | + if not found: # pragma: no cover |
| 72 | + # This branch should never be reached. |
| 73 | + raise AssertionError( |
| 74 | + "Could not identify union. This should never happen." |
| 75 | + ) |
| 76 | + |
| 77 | + # Delete any unions that are contained in strictly bigger unions. We check for |
| 78 | + # strictly inequality because any union includes itself. |
| 79 | + for i in range(len(found_unions) - 1, -1, -1): |
| 80 | + for union in found_unions: |
| 81 | + if found_unions[i] < union: |
| 82 | + del found_unions[i] |
| 83 | + del found_positions[i] |
| 84 | + del found_aliases[i] |
64 | 85 | break |
65 | | - else: # pragma: no cover |
66 | | - assert_never(union) |
67 | 86 |
|
68 | | - # Delete any unions that are contained in strictly bigger unions. We check |
69 | | - # for strictly inequality because any union includes itself. |
70 | | - for i in range(len(found_unions) - 1, -1, -1): |
| 87 | + # Create a set with all arguments of all found unions. |
| 88 | + found_args = set() |
71 | 89 | for union in found_unions: |
72 | | - if found_unions[i] < union: |
73 | | - del found_unions[i] |
74 | | - del found_positions[i] |
75 | | - del found_aliases[i] |
76 | | - break |
77 | | - |
78 | | - # Create a set with all arguments of all found unions. |
79 | | - found_args = set() |
80 | | - for union in found_unions: |
81 | | - found_args |= union |
82 | | - |
83 | | - # Insert the aliases right before the first found argument. When we insert |
84 | | - # an element, the positions of following insertions need to be appropriately |
85 | | - # incremented. |
86 | | - args = list(args) |
87 | | - # Sort by insertion position to ensure that all following insertions are at |
88 | | - # higher indices. This makes the bookkeeping simple. |
89 | | - for delta, (i, alias) in enumerate( |
90 | | - sorted(zip(found_positions, found_aliases, strict=True), key=lambda x: x[0]) |
91 | | - ): |
92 | | - args.insert(i + delta, alias) |
93 | | - |
94 | | - # Filter all elements of unions that are aliased. |
95 | | - new_args = () |
96 | | - for arg in args: |
97 | | - if arg not in found_args: |
98 | | - new_args += (arg,) |
99 | | - args = new_args |
100 | | - |
101 | | - # Generate a string representation. |
102 | | - args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args] |
103 | | - # Like `typing` does, print `Optional` whenever possible. |
104 | | - if len(args) == 2: |
105 | | - if args[0] is type(None): # noqa: E721 |
106 | | - return f"typing.Optional[{args_repr[1]}]" |
107 | | - elif args[1] is type(None): # noqa: E721 |
108 | | - return f"typing.Optional[{args_repr[0]}]" |
109 | | - # We would like to just print `args_repr[0]` whenever `len(args) == 1`, but |
110 | | - # this might break code that parses how unions print. |
111 | | - return "typing.Union[" + ", ".join(args_repr) + "]" |
112 | | - |
113 | | - |
114 | | -@wraps(_original_str) |
115 | | -def _new_str(self: object) -> str: |
116 | | - """Does the same as :func:`_new_repr`. |
117 | | -
|
118 | | - Returns: |
119 | | - str: Representation of the `typing.Union` taking into account union aliases. |
120 | | - """ |
121 | | - return _new_repr(self) |
122 | | - |
123 | | - |
124 | | -def activate_union_aliases() -> None: |
125 | | - """When printing `typing.Union`s, replace all aliased unions by the aliased names. |
126 | | - This monkey patches `__repr__` and `__str__` for `typing.Union`.""" |
127 | | - _union_type.__repr__ = _new_repr |
128 | | - _union_type.__str__ = _new_str |
129 | | - |
130 | | - |
131 | | -def deactivate_union_aliases() -> None: |
132 | | - """Undo what :func:`.alias.activate` did. This restores the original `__repr__` |
133 | | - and `__str__` for `typing.Union`.""" |
134 | | - _union_type.__repr__ = _original_repr |
135 | | - _union_type.__str__ = _original_str |
136 | | - |
137 | | - |
138 | | -_ALIASED_UNIONS: list = [] |
139 | | - |
140 | | - |
141 | | -def set_union_alias(union: UnionT, alias: str) -> UnionT: |
142 | | - """Change how a `typing.Union` is printed. This does not modify `union`. |
143 | | -
|
144 | | - Args: |
145 | | - union (type or type hint): A union. |
146 | | - alias (str): How to print `union`. |
147 | | -
|
148 | | - Returns: |
149 | | - type or type hint: `union`. |
150 | | - """ |
151 | | - args = get_args(union) if isinstance(union, _union_type) else (union,) |
152 | | - for existing_union, existing_alias in _ALIASED_UNIONS: |
153 | | - if set(existing_union) == set(args) and alias != existing_alias: |
154 | | - if isinstance(union, _union_type): |
155 | | - union_str = _original_str(union) |
156 | | - else: |
157 | | - union_str = repr(union) |
158 | | - raise RuntimeError(f"`{union_str}` already has alias `{existing_alias}`.") |
159 | | - _ALIASED_UNIONS.append((args, alias)) |
160 | | - return union |
| 90 | + found_args |= union |
| 91 | + |
| 92 | + # Insert the aliases right before the first found argument. When we insert an |
| 93 | + # element, the positions of following insertions need to be appropriately |
| 94 | + # incremented. |
| 95 | + args = list(args) |
| 96 | + # Sort by insertion position to ensure that all following insertions are |
| 97 | + # at higher indices. This makes the bookkeeping simple. |
| 98 | + for delta, (i, alias) in enumerate( |
| 99 | + sorted(zip(found_positions, found_aliases), key=lambda x: x[0]) |
| 100 | + ): |
| 101 | + args.insert(i + delta, alias) |
| 102 | + |
| 103 | + # Filter all elements of unions that are aliased. |
| 104 | + new_args = () |
| 105 | + for arg in args: |
| 106 | + if arg not in found_args: |
| 107 | + new_args += (arg,) |
| 108 | + args = new_args |
| 109 | + |
| 110 | + # Generate a string representation. |
| 111 | + args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args] |
| 112 | + # Like `typing` does, print `Optional` whenever possible. |
| 113 | + if len(args) == 2: |
| 114 | + if args[0] is type(None): # noqa: E721 |
| 115 | + return f"typing.Optional[{args_repr[1]}]" |
| 116 | + elif args[1] is type(None): # noqa: E721 |
| 117 | + return f"typing.Optional[{args_repr[0]}]" |
| 118 | + # We would like to just print `args_repr[0]` whenever `len(args) == 1`, but |
| 119 | + # this might break code that parses how unions print. |
| 120 | + return "typing.Union[" + ", ".join(args_repr) + "]" |
| 121 | + |
| 122 | + @wraps(_original_str) |
| 123 | + def _new_str(self: object) -> str: |
| 124 | + """Does the same as :func:`_new_repr`. |
| 125 | +
|
| 126 | + Returns: |
| 127 | + str: Representation of the `typing.Union` taking into account union aliases. |
| 128 | + """ |
| 129 | + return _new_repr(self) |
| 130 | + |
| 131 | + def activate_union_aliases() -> None: |
| 132 | + """When printing `typing.Union`s, replace aliased unions by the aliased names. |
| 133 | + This monkey patches `__repr__` and `__str__` for `typing.Union`.""" |
| 134 | + _union_type.__repr__ = _new_repr |
| 135 | + _union_type.__str__ = _new_str |
| 136 | + |
| 137 | + def deactivate_union_aliases() -> None: |
| 138 | + """Undo what :func:`.alias.activate` did. This restores the original `__repr__` |
| 139 | + and `__str__` for `typing.Union`.""" |
| 140 | + _union_type.__repr__ = _original_repr |
| 141 | + _union_type.__str__ = _original_str |
| 142 | + |
| 143 | + _ALIASED_UNIONS: list = [] |
| 144 | + |
| 145 | + def set_union_alias(union: UnionT, alias: str) -> UnionT: |
| 146 | + """Change how a `typing.Union` is printed. This does not modify `union`. |
| 147 | +
|
| 148 | + Args: |
| 149 | + union (type or type hint): A union. |
| 150 | + alias (str): How to print `union`. |
| 151 | +
|
| 152 | + Returns: |
| 153 | + type or type hint: `union`. |
| 154 | + """ |
| 155 | + if sys.version_info >= (3, 14): |
| 156 | + return union |
| 157 | + |
| 158 | + args = get_args(union) if isinstance(union, _union_type) else (union,) |
| 159 | + for existing_union, existing_alias in _ALIASED_UNIONS: |
| 160 | + if set(existing_union) == set(args) and alias != existing_alias: |
| 161 | + if isinstance(union, _union_type): |
| 162 | + union_str = _original_str(union) |
| 163 | + else: |
| 164 | + union_str = repr(union) |
| 165 | + raise RuntimeError( |
| 166 | + f"`{union_str}` already has alias `{existing_alias}`." |
| 167 | + ) |
| 168 | + _ALIASED_UNIONS.append((args, alias)) |
| 169 | + return union |
| 170 | + |
| 171 | + |
| 172 | +else: |
| 173 | + |
| 174 | + @deprecated( |
| 175 | + "Plum's union aliasing is not supported on Python 3.14 and later.", |
| 176 | + category=RuntimeWarning, |
| 177 | + stacklevel=2, |
| 178 | + ) |
| 179 | + def activate_union_aliases() -> None: |
| 180 | + """When printing `typing.Union`s, replace aliased unions by the aliased names. |
| 181 | + This monkey patches `__repr__` and `__str__` for `typing.Union`.""" |
| 182 | + |
| 183 | + @deprecated( |
| 184 | + "Plum's union aliasing is not supported on Python 3.14 and later.", |
| 185 | + category=RuntimeWarning, |
| 186 | + stacklevel=2, |
| 187 | + ) |
| 188 | + def deactivate_union_aliases() -> None: |
| 189 | + """Undo what :func:`.alias.activate` did. This restores the original `__repr__` |
| 190 | + and `__str__` for `typing.Union`.""" |
| 191 | + if sys.version_info < (3, 14): |
| 192 | + _union_type.__repr__ = _original_repr |
| 193 | + _union_type.__str__ = _original_str |
| 194 | + |
| 195 | + @deprecated( |
| 196 | + "Plum's union aliasing is not supported on Python 3.14 and later.", |
| 197 | + category=RuntimeWarning, |
| 198 | + stacklevel=2, |
| 199 | + ) |
| 200 | + def set_union_alias(union: UnionT, alias: str) -> UnionT: |
| 201 | + """Change how a `typing.Union` is printed. This does not modify `union`. |
| 202 | +
|
| 203 | + Args: |
| 204 | + union (type or type hint): A union. |
| 205 | + alias (str): How to print `union`. |
| 206 | +
|
| 207 | + Returns: |
| 208 | + type or type hint: `union`. |
| 209 | + """ |
| 210 | + return union |
0 commit comments