11from __future__ import annotations
22
3+ from collections import defaultdict
34from contextlib import contextmanager
45from contextvars import ContextVar
56from dataclasses import dataclass
6- from typing import TYPE_CHECKING , NewType , TypeVar , cast
7+ from typing import TYPE_CHECKING , TypeVar , cast
78
89from .declarations import *
910from .pretty import *
1516
1617 from .egraph import Expr
1718
18- __all__ = ["convert" , "convert_to_same_type" , "converter" , "resolve_literal" ]
19+ __all__ = ["convert" , "convert_to_same_type" , "converter" , "resolve_literal" , "ConvertError" ]
1920# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
20- TypeName = NewType ("TypeName" , str )
21- CONVERSIONS : dict [tuple [type | TypeName , TypeName ], tuple [int , Callable ]] = {}
21+ CONVERSIONS : dict [tuple [type | JustTypeRef , JustTypeRef ], tuple [int , Callable ]] = {}
2222# Global declerations to store all convertable types so we can query if they have certain methods or not
2323_CONVERSION_DECLS = Declarations .create ()
2424# Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing
@@ -45,12 +45,12 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
4545 Register a converter from some type to an egglog type.
4646 """
4747 to_type_name = process_tp (to_type )
48- if not isinstance (to_type_name , str ):
48+ if not isinstance (to_type_name , JustTypeRef ):
4949 raise TypeError (f"Expected return type to be a egglog type, got { to_type_name } " )
5050 _register_converter (process_tp (from_type ), to_type_name , fn , cost )
5151
5252
53- def _register_converter (a : type | TypeName , b : TypeName , a_b : Callable , cost : int ) -> None :
53+ def _register_converter (a : type | JustTypeRef , b : JustTypeRef , a_b : Callable , cost : int ) -> None :
5454 """
5555 Registers a converter from some type to an egglog type, if not already registered.
5656
@@ -63,10 +63,26 @@ def _register_converter(a: type | TypeName, b: TypeName, a_b: Callable, cost: in
6363 return
6464 CONVERSIONS [(a , b )] = (cost , a_b )
6565 for (c , d ), (other_cost , c_d ) in list (CONVERSIONS .items ()):
66- if b == c :
67- _register_converter (a , d , _ComposedConverter (a_b , c_d ), cost + other_cost )
68- if a == d :
69- _register_converter (c , b , _ComposedConverter (c_d , a_b ), cost + other_cost )
66+ if _is_type_compatible (b , c ):
67+ _register_converter (
68+ a , d , _ComposedConverter (a_b , c_d , c .args if isinstance (c , JustTypeRef ) else ()), cost + other_cost
69+ )
70+ if _is_type_compatible (a , d ):
71+ _register_converter (
72+ c , b , _ComposedConverter (c_d , a_b , a .args if isinstance (a , JustTypeRef ) else ()), cost + other_cost
73+ )
74+
75+
76+ def _is_type_compatible (source : type | JustTypeRef , target : type | JustTypeRef ) -> bool :
77+ """
78+ Types must be equal or also support unbound to bound typevar like B -> B[C]
79+ """
80+ if source == target :
81+ return True
82+ if isinstance (source , JustTypeRef ) and isinstance (target , JustTypeRef ) and source .args and not target .args :
83+ return source .name == target .name
84+ # TODO: Support case where B[T] where T is typevar is mapped to B[C]
85+ return False
7086
7187
7288@dataclass
@@ -81,9 +97,17 @@ class _ComposedConverter:
8197
8298 a_b : Callable
8399 b_c : Callable
100+ b_args : tuple [JustTypeRef , ...]
84101
85102 def __call__ (self , x : object ) -> object :
86- return self .b_c (self .a_b (x ))
103+ # if we have A -> B and B[C] -> D then we should use (C,) as the type args
104+ # when converting from A -> B
105+ if self .b_args :
106+ with with_type_args (self .b_args , _retrieve_conversion_decls ):
107+ first_res = self .a_b (x )
108+ else :
109+ first_res = self .a_b (x )
110+ return self .b_c (first_res )
87111
88112 def __str__ (self ) -> str :
89113 return f"{ self .b_c } ∘ { self .a_b } "
@@ -105,35 +129,33 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
105129 return resolve_literal (tp .to_var (), source , Thunk .value (target .__egg_decls__ ))
106130
107131
108- def process_tp (tp : type | RuntimeClass ) -> TypeName | type :
132+ def process_tp (tp : type | RuntimeClass ) -> JustTypeRef | type :
109133 """
110134 Process a type before converting it, to add it to the global declerations and resolve to a ref.
111135 """
112136 if isinstance (tp , RuntimeClass ):
113137 _TO_PROCESS_DECLS .append (tp )
114138 egg_tp = tp .__egg_tp__
115- if egg_tp .args :
116- raise TypeError (f"Cannot register a converter for a generic type, got { tp } " )
117- return TypeName (egg_tp .name )
139+ return egg_tp .to_just ()
118140 return tp
119141
120142
121- def min_convertable_tp (a : object , b : object , name : str ) -> TypeName :
143+ def min_convertable_tp (a : object , b : object , name : str ) -> JustTypeRef :
122144 """
123145 Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
124146 """
125147 decls = _retrieve_conversion_decls ()
126148 a_tp = _get_tp (a )
127149 b_tp = _get_tp (b )
128150 a_converts_to = {
129- to : c for ((from_ , to ), (c , _ )) in CONVERSIONS .items () if from_ == a_tp and decls .has_method (to , name )
151+ to : c for ((from_ , to ), (c , _ )) in CONVERSIONS .items () if from_ == a_tp and decls .has_method (to . name , name )
130152 }
131153 b_converts_to = {
132- to : c for ((from_ , to ), (c , _ )) in CONVERSIONS .items () if from_ == b_tp and decls .has_method (to , name )
154+ to : c for ((from_ , to ), (c , _ )) in CONVERSIONS .items () if from_ == b_tp and decls .has_method (to . name , name )
133155 }
134- if isinstance (a_tp , str ):
156+ if isinstance (a_tp , JustTypeRef ):
135157 a_converts_to [a_tp ] = 0
136- if isinstance (b_tp , str ):
158+ if isinstance (b_tp , JustTypeRef ):
137159 b_converts_to [b_tp ] = 0
138160 common = set (a_converts_to ) & set (b_converts_to )
139161 if not common :
@@ -176,27 +198,38 @@ def resolve_literal(
176198 # If this is a var, it has to be a runtime expession
177199 assert isinstance (arg , RuntimeExpr ), f"Expected a runtime expression, got { arg } "
178200 return arg
179- tp_name = TypeName (tp_just .name )
180- if arg_type == tp_name :
201+ if arg_type == tp_just :
181202 # If the type is an egg type, it has to be a runtime expr
182203 assert isinstance (arg , RuntimeExpr )
183204 return arg
184205 # Try all parent types as well, if we are converting from a Python type
185206 for arg_type_instance in arg_type .__mro__ if isinstance (arg_type , type ) else [arg_type ]:
186- try :
187- fn = CONVERSIONS [(arg_type_instance , tp_name )][1 ]
188- except KeyError :
189- continue
190- break
207+ if (key := (arg_type_instance , tp_just )) in CONVERSIONS :
208+ fn = CONVERSIONS [key ][1 ]
209+ break
210+ # Try broadening if we have a convert to the general type instead of the specific one too, for generics
211+ if tp_just .args and (key := (arg_type_instance , JustTypeRef (tp_just .name ))) in CONVERSIONS :
212+ fn = CONVERSIONS [key ][1 ]
213+ break
214+ # if we didn't find any raise an error
191215 else :
192- raise ConvertError (f"Cannot convert { arg_type } to { tp_name } " )
216+ raise ConvertError (f"Cannot convert { arg_type } to { tp_just } " )
193217 with with_type_args (tp_just .args , decls ):
194218 return fn (arg )
195219
196220
197- def _get_tp (x : object ) -> TypeName | type :
221+ def _debug_print_converers ():
222+ """
223+ Prints a mapping of all source types to target types that have a conversion function.
224+ """
225+ source_to_targets = defaultdict (list )
226+ for source , target in CONVERSIONS :
227+ source_to_targets [source ].append (target )
228+
229+
230+ def _get_tp (x : object ) -> JustTypeRef | type :
198231 if isinstance (x , RuntimeExpr ):
199- return TypeName ( x .__egg_typed_expr__ .tp . name )
232+ return x .__egg_typed_expr__ .tp
200233 tp = type (x )
201234 # If this value has a custom metaclass, let's use that as our index instead of the type
202235 if type (tp ) is not type :
0 commit comments