3434
3535import sys
3636from functools import wraps
37- from typing import TypeVar , Union , _type_repr , get_args
38- from typing_extensions import deprecated
37+ from typing import Any , TypeVar , Union , _type_repr , get_args
38+ from typing_extensions import TypeAliasType , deprecated
3939
4040UnionT = TypeVar ("UnionT" )
4141
42- _ALIASED_UNIONS : list = []
42+ _union_type = type ( Union [ int , float ]) # noqa: UP007
4343
44- if sys .version_info < (3 , 14 ):
45- _union_type = type (Union [int , float ]) # noqa: UP007
44+ if sys .version_info < (3 , 14 ): # pragma: specific no cover 3.14
4645 _original_repr = _union_type .__repr__
4746 _original_str = _union_type .__str__
4847
48+ _ALIASED_UNIONS : dict [tuple [Any , ...], str ] = {}
49+
4950 @wraps (_original_repr )
5051 def _new_repr (self : object ) -> str :
5152 """Print a `typing.Union`, replacing all aliased unions by their aliased names.
@@ -60,7 +61,7 @@ def _new_repr(self: object) -> str:
6061 found_unions = []
6162 found_positions = []
6263 found_aliases = []
63- for union , alias in reversed (_ALIASED_UNIONS ):
64+ for union , alias in reversed (_ALIASED_UNIONS . items () ):
6465 union_set = set (union )
6566 if union_set <= args_set :
6667 found = False
@@ -77,40 +78,30 @@ def _new_repr(self: object) -> str:
7778 "Could not identify union. This should never happen."
7879 )
7980
80- # Delete any unions that are contained in strictly bigger unions. We check for
81- # strictly inequality because any union includes itself.
81+ # Delete any unions that are contained in strictly bigger unions. We
82+ # check for strictly inequality because any union includes itself.
8283 for i in range (len (found_unions ) - 1 , - 1 , - 1 ):
83- for union in found_unions :
84- if found_unions [i ] < union :
84+ for union_ in found_unions :
85+ if found_unions [i ] < set ( union_ ) :
8586 del found_unions [i ]
8687 del found_positions [i ]
8788 del found_aliases [i ]
8889 break
8990
9091 # Create a set with all arguments of all found unions.
91- found_args = set ()
92- for union in found_unions :
93- found_args |= union
94-
95- # Insert the aliases right before the first found argument. When we insert an
96- # element, the positions of following insertions need to be appropriately
97- # incremented.
98- args = list (args )
99- # Sort by insertion position to ensure that all following insertions are
100- # at higher indices. This makes the bookkeeping simple.
101- for delta , (i , alias ) in enumerate (
102- sorted (
103- zip (found_positions , found_aliases , strict = False ), key = lambda x : x [0 ]
104- )
105- ):
106- args .insert (i + delta , alias )
92+ found_args = set ().union (* found_unions ) if found_unions else set ()
93+
94+ # Build a mapping from original position to aliases to insert before it.
95+ inserts : dict [int , list [str ]] = {}
96+ for pos , alias in zip (found_positions , found_aliases , strict = False ):
97+ inserts .setdefault (pos , []).append (alias )
98+ # Interleave aliases at the appropriate positions.
99+ args = tuple (
100+ v for i , arg in enumerate (args ) for v in (* inserts .pop (i , []), arg )
101+ )
107102
108103 # Filter all elements of unions that are aliased.
109- new_args = ()
110- for arg in args :
111- if arg not in found_args :
112- new_args += (arg ,)
113- args = new_args
104+ args = tuple (arg for arg in args if arg not in found_args )
114105
115106 # Generate a string representation.
116107 args_repr = [a if isinstance (a , str ) else _type_repr (a ) for a in args ]
@@ -140,8 +131,8 @@ def _new_str(self: object) -> str:
140131 def activate_union_aliases () -> None :
141132 """When printing `typing.Union`s, replace aliased unions by the aliased names.
142133 This monkey patches `__repr__` and `__str__` for `typing.Union`."""
143- _union_type .__repr__ = _new_repr
144- _union_type .__str__ = _new_str
134+ _union_type .__repr__ = _new_repr # type: ignore[method-assign]
135+ _union_type .__str__ = _new_str # type: ignore[method-assign]
145136
146137 @deprecated (
147138 "`deactivate_union_aliases` is deprecated and will be removed in a future version." , # noqa: E501
@@ -150,13 +141,9 @@ def activate_union_aliases() -> None:
150141 def deactivate_union_aliases () -> None :
151142 """Undo what :func:`.alias.activate` did. This restores the original `__repr__`
152143 and `__str__` for `typing.Union`."""
153- _union_type .__repr__ = _original_repr
154- _union_type .__str__ = _original_str
144+ _union_type .__repr__ = _original_repr # type: ignore[method-assign]
145+ _union_type .__str__ = _original_str # type: ignore[method-assign]
155146
156- @deprecated (
157- "`set_union_alias` is deprecated and will be removed in a future version." , # noqa: E501
158- stacklevel = 2 ,
159- )
160147 def set_union_alias (union : UnionT , alias : str ) -> UnionT :
161148 """Change how a `typing.Union` is printed. This does not modify `union`.
162149
@@ -168,7 +155,7 @@ def set_union_alias(union: UnionT, alias: str) -> UnionT:
168155 type or type hint: `union`.
169156 """
170157 args = get_args (union ) if isinstance (union , _union_type ) else (union ,)
171- for existing_union , existing_alias in _ALIASED_UNIONS :
158+ for existing_union , existing_alias in _ALIASED_UNIONS . items () :
172159 if set (existing_union ) == set (args ) and alias != existing_alias :
173160 if isinstance (union , _union_type ):
174161 union_str = _original_str (union )
@@ -177,11 +164,11 @@ def set_union_alias(union: UnionT, alias: str) -> UnionT:
177164 raise RuntimeError (
178165 f"`{ union_str } ` already has alias `{ existing_alias } `."
179166 )
180- _ALIASED_UNIONS . append (( args , alias ))
167+ _ALIASED_UNIONS [ args ] = alias
181168 return union
182169
183-
184- else :
170+ else : # pragma: specific no cover 3.13 3.12 3.11 3.10
171+ _ALIASED_UNIONS : dict [ tuple [ Any , ...], TypeAliasType ] = {}
185172
186173 @deprecated (
187174 "`activate_union_aliases` is deprecated and will be removed in a future version." , # noqa: E501
@@ -200,23 +187,60 @@ def activate_union_aliases() -> None:
200187 def deactivate_union_aliases () -> None :
201188 """Undo what :func:`.alias.activate` did. This restores the original `__repr__`
202189 and `__str__` for `typing.Union`."""
203- if sys .version_info < (3 , 14 ):
204- _union_type .__repr__ = _original_repr
205- _union_type .__str__ = _original_str
206190
207- @deprecated (
208- "`set_union_alias` is deprecated and will be removed in a future version." , # noqa: E501
209- category = RuntimeWarning ,
210- stacklevel = 2 ,
211- )
212- def set_union_alias (union : UnionT , alias : str ) -> UnionT :
213- """Change how a `typing.Union` is printed. This does not modify `union`.
191+ def set_union_alias (union : UnionT , / , alias : str ) -> UnionT :
192+ """Register a union alias for use in plum's dispatch system.
193+
194+ When used with plum's dispatch system, the union will be automatically
195+ transformed into a `TypeAliasType` during signature extraction, allowing
196+ dispatch to key off the alias name instead of the union structure.
214197
215198 Args:
216- union (type or type hint): A union.
217- alias (str): How to print ` union` .
199+ union (type or type hint): A union type or a single type .
200+ alias (str): Alias name for the union.
218201
219- Returns:
220- type or type hint: `union`.
221202 """
203+ # Handle both union types and single types, matching < 3.14 behaviour.
204+ args = get_args (union ) if isinstance (union , _union_type ) else (union ,)
205+
206+ # Check for conflicting aliases
207+ for existing_union , existing_alias in _ALIASED_UNIONS .items ():
208+ if set (existing_union ) == set (args ) and alias != repr (existing_alias ):
209+ union_str = repr (union )
210+ raise RuntimeError (
211+ f"`{ union_str } ` already has alias `{ existing_alias !r} `."
212+ )
213+
214+ new_alias = TypeAliasType (alias , union , type_params = ()) # type: ignore[misc]
215+
216+ _ALIASED_UNIONS [args ] = new_alias
217+
222218 return union
219+
220+
221+ def _transform_union_alias (x : object , / ) -> object :
222+ """Transform a Union type hint to a TypeAliasType if it's registered in the alias
223+ registry. This is used by plum's dispatch machinery to use aliased names for unions.
224+
225+ Args:
226+ x (type or type hint): Type hint, potentially a Union.
227+
228+ Returns:
229+ type or type hint: If `x` is a Union registered in `_ALIASED_UNIONS`, returns
230+ the TypeAliasType. Otherwise returns `x` unchanged.
231+ """
232+ # TypeAliasType instances are already transformed, return as-is
233+ if isinstance (x , TypeAliasType ):
234+ return x
235+
236+ # Get the union args to check if it's registered
237+ args = get_args (x ) if isinstance (x , _union_type ) else None
238+ if args :
239+ args_set = set (args )
240+ # Look for a matching alias in the registry
241+ for union_args , type_alias in _ALIASED_UNIONS .items ():
242+ if set (union_args ) == args_set :
243+ return type_alias
244+
245+ # Not a union or not aliased, return as-is
246+ return x
0 commit comments