Skip to content

Commit 76e1a75

Browse files
committed
feat: union alias in python 3.14
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
1 parent 9a21467 commit 76e1a75

File tree

5 files changed

+371
-89
lines changed

5 files changed

+371
-89
lines changed

docs/union_aliases.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,8 @@ You can deactivate union aliases with `deactivate_union_aliases`:
145145
>>> import warnings
146146
>>> from plum import deactivate_union_aliases
147147

148-
>>> with warnings.catch_warnings(action="ignore"):
148+
>>> with warnings.catch_warnings():
149+
... warnings.simplefilter("ignore")
149150
... deactivate_union_aliases()
150151

151152
% skip: next "Result depends on NumPy version."

src/plum/_alias.py

Lines changed: 81 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,19 @@
3434

3535
import sys
3636
from functools import wraps
37-
from typing import TypeVar, Union, _type_repr, get_args
38-
from typing_extensions import deprecated
37+
from typing import Any, TypeVar, Union, _type_repr, get_args
38+
from typing_extensions import TypeAliasType, deprecated
3939

4040
UnionT = TypeVar("UnionT")
4141

42-
_ALIASED_UNIONS: list = []
42+
_union_type = type(Union[int, float]) # noqa: UP007
4343

44-
if sys.version_info < (3, 14):
45-
_union_type = type(Union[int, float]) # noqa: UP007
44+
if sys.version_info < (3, 14): # pragma: specific no cover 3.14
4645
_original_repr = _union_type.__repr__
4746
_original_str = _union_type.__str__
4847

48+
_ALIASED_UNIONS: dict[tuple[Any, ...], str] = {}
49+
4950
@wraps(_original_repr)
5051
def _new_repr(self: object) -> str:
5152
"""Print a `typing.Union`, replacing all aliased unions by their aliased names.
@@ -60,7 +61,7 @@ def _new_repr(self: object) -> str:
6061
found_unions = []
6162
found_positions = []
6263
found_aliases = []
63-
for union, alias in reversed(_ALIASED_UNIONS):
64+
for union, alias in reversed(_ALIASED_UNIONS.items()):
6465
union_set = set(union)
6566
if union_set <= args_set:
6667
found = False
@@ -77,40 +78,30 @@ def _new_repr(self: object) -> str:
7778
"Could not identify union. This should never happen."
7879
)
7980

80-
# Delete any unions that are contained in strictly bigger unions. We check for
81-
# strictly inequality because any union includes itself.
81+
# Delete any unions that are contained in strictly bigger unions. We
82+
# check for strictly inequality because any union includes itself.
8283
for i in range(len(found_unions) - 1, -1, -1):
83-
for union in found_unions:
84-
if found_unions[i] < union:
84+
for union_ in found_unions:
85+
if found_unions[i] < set(union_):
8586
del found_unions[i]
8687
del found_positions[i]
8788
del found_aliases[i]
8889
break
8990

9091
# Create a set with all arguments of all found unions.
91-
found_args = set()
92-
for union in found_unions:
93-
found_args |= union
94-
95-
# Insert the aliases right before the first found argument. When we insert an
96-
# element, the positions of following insertions need to be appropriately
97-
# incremented.
98-
args = list(args)
99-
# Sort by insertion position to ensure that all following insertions are
100-
# at higher indices. This makes the bookkeeping simple.
101-
for delta, (i, alias) in enumerate(
102-
sorted(
103-
zip(found_positions, found_aliases, strict=False), key=lambda x: x[0]
104-
)
105-
):
106-
args.insert(i + delta, alias)
92+
found_args = set().union(*found_unions) if found_unions else set()
93+
94+
# Build a mapping from original position to aliases to insert before it.
95+
inserts: dict[int, list[str]] = {}
96+
for pos, alias in zip(found_positions, found_aliases, strict=False):
97+
inserts.setdefault(pos, []).append(alias)
98+
# Interleave aliases at the appropriate positions.
99+
args = tuple(
100+
v for i, arg in enumerate(args) for v in (*inserts.pop(i, []), arg)
101+
)
107102

108103
# Filter all elements of unions that are aliased.
109-
new_args = ()
110-
for arg in args:
111-
if arg not in found_args:
112-
new_args += (arg,)
113-
args = new_args
104+
args = tuple(arg for arg in args if arg not in found_args)
114105

115106
# Generate a string representation.
116107
args_repr = [a if isinstance(a, str) else _type_repr(a) for a in args]
@@ -140,8 +131,8 @@ def _new_str(self: object) -> str:
140131
def activate_union_aliases() -> None:
141132
"""When printing `typing.Union`s, replace aliased unions by the aliased names.
142133
This monkey patches `__repr__` and `__str__` for `typing.Union`."""
143-
_union_type.__repr__ = _new_repr
144-
_union_type.__str__ = _new_str
134+
_union_type.__repr__ = _new_repr # type: ignore[method-assign]
135+
_union_type.__str__ = _new_str # type: ignore[method-assign]
145136

146137
@deprecated(
147138
"`deactivate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501
@@ -150,13 +141,9 @@ def activate_union_aliases() -> None:
150141
def deactivate_union_aliases() -> None:
151142
"""Undo what :func:`.alias.activate` did. This restores the original `__repr__`
152143
and `__str__` for `typing.Union`."""
153-
_union_type.__repr__ = _original_repr
154-
_union_type.__str__ = _original_str
144+
_union_type.__repr__ = _original_repr # type: ignore[method-assign]
145+
_union_type.__str__ = _original_str # type: ignore[method-assign]
155146

156-
@deprecated(
157-
"`set_union_alias` is deprecated and will be removed in a future version.", # noqa: E501
158-
stacklevel=2,
159-
)
160147
def set_union_alias(union: UnionT, alias: str) -> UnionT:
161148
"""Change how a `typing.Union` is printed. This does not modify `union`.
162149
@@ -168,7 +155,7 @@ def set_union_alias(union: UnionT, alias: str) -> UnionT:
168155
type or type hint: `union`.
169156
"""
170157
args = get_args(union) if isinstance(union, _union_type) else (union,)
171-
for existing_union, existing_alias in _ALIASED_UNIONS:
158+
for existing_union, existing_alias in _ALIASED_UNIONS.items():
172159
if set(existing_union) == set(args) and alias != existing_alias:
173160
if isinstance(union, _union_type):
174161
union_str = _original_str(union)
@@ -177,11 +164,11 @@ def set_union_alias(union: UnionT, alias: str) -> UnionT:
177164
raise RuntimeError(
178165
f"`{union_str}` already has alias `{existing_alias}`."
179166
)
180-
_ALIASED_UNIONS.append((args, alias))
167+
_ALIASED_UNIONS[args] = alias
181168
return union
182169

183-
184-
else:
170+
else: # pragma: specific no cover 3.13 3.12 3.11 3.10
171+
_ALIASED_UNIONS: dict[tuple[Any, ...], TypeAliasType] = {}
185172

186173
@deprecated(
187174
"`activate_union_aliases` is deprecated and will be removed in a future version.", # noqa: E501
@@ -200,23 +187,60 @@ def activate_union_aliases() -> None:
200187
def deactivate_union_aliases() -> None:
201188
"""Undo what :func:`.alias.activate` did. This restores the original `__repr__`
202189
and `__str__` for `typing.Union`."""
203-
if sys.version_info < (3, 14):
204-
_union_type.__repr__ = _original_repr
205-
_union_type.__str__ = _original_str
206190

207-
@deprecated(
208-
"`set_union_alias` is deprecated and will be removed in a future version.", # noqa: E501
209-
category=RuntimeWarning,
210-
stacklevel=2,
211-
)
212-
def set_union_alias(union: UnionT, alias: str) -> UnionT:
213-
"""Change how a `typing.Union` is printed. This does not modify `union`.
191+
def set_union_alias(union: UnionT, /, alias: str) -> UnionT:
192+
"""Register a union alias for use in plum's dispatch system.
193+
194+
When used with plum's dispatch system, the union will be automatically
195+
transformed into a `TypeAliasType` during signature extraction, allowing
196+
dispatch to key off the alias name instead of the union structure.
214197
215198
Args:
216-
union (type or type hint): A union.
217-
alias (str): How to print `union`.
199+
union (type or type hint): A union type or a single type.
200+
alias (str): Alias name for the union.
218201
219-
Returns:
220-
type or type hint: `union`.
221202
"""
203+
# Handle both union types and single types, matching < 3.14 behaviour.
204+
args = get_args(union) if isinstance(union, _union_type) else (union,)
205+
206+
# Check for conflicting aliases
207+
for existing_union, existing_alias in _ALIASED_UNIONS.items():
208+
if set(existing_union) == set(args) and alias != repr(existing_alias):
209+
union_str = repr(union)
210+
raise RuntimeError(
211+
f"`{union_str}` already has alias `{existing_alias!r}`."
212+
)
213+
214+
new_alias = TypeAliasType(alias, union, type_params=()) # type: ignore[misc]
215+
216+
_ALIASED_UNIONS[args] = new_alias
217+
222218
return union
219+
220+
221+
def _transform_union_alias(x: object, /) -> object:
222+
"""Transform a Union type hint to a TypeAliasType if it's registered in the alias
223+
registry. This is used by plum's dispatch machinery to use aliased names for unions.
224+
225+
Args:
226+
x (type or type hint): Type hint, potentially a Union.
227+
228+
Returns:
229+
type or type hint: If `x` is a Union registered in `_ALIASED_UNIONS`, returns
230+
the TypeAliasType. Otherwise returns `x` unchanged.
231+
"""
232+
# TypeAliasType instances are already transformed, return as-is
233+
if isinstance(x, TypeAliasType):
234+
return x
235+
236+
# Get the union args to check if it's registered
237+
args = get_args(x) if isinstance(x, _union_type) else None
238+
if args:
239+
args_set = set(args)
240+
# Look for a matching alias in the registry
241+
for union_args, type_alias in _ALIASED_UNIONS.items():
242+
if set(union_args) == args_set:
243+
return type_alias
244+
245+
# Not a union or not aliased, return as-is
246+
return x

src/plum/repr.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
"repr_pyfunction",
66
"rich_repr",
77
]
8-
98
import inspect
109
import os
1110
import sys
@@ -14,12 +13,15 @@
1413
from collections.abc import Callable, Iterable
1514
from functools import partial
1615
from typing import Any, TypeVar, overload
16+
from typing_extensions import TypeAliasType
1717

1818
import rich
1919
from rich.color import Color
2020
from rich.style import Style
2121
from rich.text import Text
2222

23+
from ._alias import _transform_union_alias
24+
2325
T = TypeVar("T")
2426

2527
path_style = Style(color=Color.from_ansi(7))
@@ -41,6 +43,9 @@ def repr_type(x: object, /) -> Text:
4143
Returns:
4244
:class:`rich.Text`: Representation.
4345
"""
46+
# Apply union aliasing if `x` is a union. This allows us to have the correct
47+
# syntax highlighting for aliased unions.
48+
x = _transform_union_alias(x)
4449

4550
if isinstance(x, type):
4651
if x.__module__ in ["builtins", "typing", "typing_extensions"]:
@@ -60,14 +65,20 @@ def repr_short(x: object, /) -> str:
6065
"""Representation as a string, but in shorter form. This just calls
6166
:func:`typing._type_repr`.
6267
68+
If the type is a union registered in plum's alias registry, the alias name
69+
is used instead.
70+
6371
Args:
6472
x (object): Object.
6573
6674
Returns:
6775
str: Shorter representation of `x`.
6876
"""
69-
# :func:`typing._type_repr` is an internal function, but it should be available in
70-
# Python versions 3.9 through 3.13.
77+
if isinstance(transformed := _transform_union_alias(x), TypeAliasType):
78+
# It's an aliased union — use the alias name
79+
return str(transformed.__name__)
80+
# :func:`typing._type_repr` is an internal function, but it should be
81+
# available in Python versions 3.9 through 3.14.
7182
return typing._type_repr(x)
7283

7384

tests/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
1+
"""Fixtures for testing."""
2+
3+
from unittest.mock import patch
4+
15
import pytest
26

37
import plum
48
from plum._promotion import _convert, _promotion_rule
59

610

11+
@pytest.fixture(autouse=True)
12+
def _clean_union_aliases():
13+
"""Give each test its own empty alias registry, restored automatically."""
14+
from plum._alias import _ALIASED_UNIONS
15+
16+
with patch.dict(_ALIASED_UNIONS, clear=True):
17+
yield
18+
19+
720
@pytest.fixture
821
def dispatch() -> plum.Dispatcher:
922
"""Provide a fresh Dispatcher for testing."""

0 commit comments

Comments
 (0)