11from __future__ import annotations
22
33from collections import defaultdict
4+ from collections .abc import Callable
45from contextlib import contextmanager
56from contextvars import ContextVar
67from dataclasses import dataclass
7- from typing import TYPE_CHECKING , TypeVar , cast
8+ from typing import TYPE_CHECKING , Any , TypeVar , cast
89
910from .declarations import *
1011from .pretty import *
1314from .type_constraint_solver import TypeConstraintError
1415
1516if TYPE_CHECKING :
16- from collections .abc import Callable , Generator
17+ from collections .abc import Generator
1718
1819 from .egraph import BaseExpr
1920 from .type_constraint_solver import TypeConstraintSolver
2021
21- __all__ = ["ConvertError" , "convert" , "convert_to_same_type" , " converter" , "resolve_literal " ]
22+ __all__ = ["ConvertError" , "convert" , "converter" , "get_type_args " ]
2223# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
23- CONVERSIONS : dict [tuple [type | JustTypeRef , JustTypeRef ], tuple [int , Callable ]] = {}
24+ CONVERSIONS : dict [tuple [type | JustTypeRef , JustTypeRef ], tuple [int , Callable [[ Any ], RuntimeExpr ] ]] = {}
2425# Global declerations to store all convertable types so we can query if they have certain methods or not
2526_CONVERSION_DECLS = Declarations .create ()
2627# Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing
2728# until we need them
2829_TO_PROCESS_DECLS : list [DeclerationsLike ] = []
2930
3031
31- def _retrieve_conversion_decls () -> Declarations :
32+ def retrieve_conversion_decls () -> Declarations :
3233 _CONVERSION_DECLS .update (* _TO_PROCESS_DECLS )
3334 _TO_PROCESS_DECLS .clear ()
3435 return _CONVERSION_DECLS
@@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
4950 to_type_name = process_tp (to_type )
5051 if not isinstance (to_type_name , JustTypeRef ):
5152 raise TypeError (f"Expected return type to be a egglog type, got { to_type_name } " )
52- _register_converter (process_tp (from_type ), to_type_name , fn , cost )
53+ _register_converter (process_tp (from_type ), to_type_name , cast ( "Callable[[Any], RuntimeExpr]" , fn ) , cost )
5354
5455
55- def _register_converter (a : type | JustTypeRef , b : JustTypeRef , a_b : Callable , cost : int ) -> None :
56+ def _register_converter (a : type | JustTypeRef , b : JustTypeRef , a_b : Callable [[ Any ], RuntimeExpr ] , cost : int ) -> None :
5657 """
5758 Registers a converter from some type to an egglog type, if not already registered.
5859
@@ -97,15 +98,15 @@ class _ComposedConverter:
9798 We use the dataclass instead of the lambda to make it easier to debug.
9899 """
99100
100- a_b : Callable
101- b_c : Callable
101+ a_b : Callable [[ Any ], RuntimeExpr ]
102+ b_c : Callable [[ Any ], RuntimeExpr ]
102103 b_args : tuple [JustTypeRef , ...]
103104
104- def __call__ (self , x : object ) -> object :
105+ def __call__ (self , x : Any ) -> RuntimeExpr :
105106 # if we have A -> B and B[C] -> D then we should use (C,) as the type args
106107 # when converting from A -> B
107108 if self .b_args :
108- with with_type_args (self .b_args , _retrieve_conversion_decls ):
109+ with with_type_args (self .b_args , retrieve_conversion_decls ):
109110 first_res = self .a_b (x )
110111 else :
111112 first_res = self .a_b (x )
@@ -142,33 +143,38 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
142143 return tp
143144
144145
145- def min_convertable_tp (a : object , b : object , name : str ) -> JustTypeRef :
146- """
147- Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
148- """
149- decls = _retrieve_conversion_decls ()
150- a_tp = _get_tp (a )
151- b_tp = _get_tp (b )
152- # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153- if not (
154- (isinstance (a_tp , JustTypeRef ) and decls .has_method (a_tp .name , name ))
155- or (isinstance (b_tp , JustTypeRef ) and decls .has_method (b_tp .name , name ))
156- ):
157- raise ConvertError (f"Neither { a_tp } nor { b_tp } has method { name } " )
158- a_converts_to = {
159- to : c for ((from_ , to ), (c , _ )) in CONVERSIONS .items () if from_ == a_tp and decls .has_method (to .name , name )
160- }
161- b_converts_to = {
162- to : c for ((from_ , to ), (c , _ )) in CONVERSIONS .items () if from_ == b_tp and decls .has_method (to .name , name )
163- }
164- if isinstance (a_tp , JustTypeRef ) and decls .has_method (a_tp .name , name ):
165- a_converts_to [a_tp ] = 0
166- if isinstance (b_tp , JustTypeRef ) and decls .has_method (b_tp .name , name ):
167- b_converts_to [b_tp ] = 0
168- common = set (a_converts_to ) & set (b_converts_to )
169- if not common :
170- raise ConvertError (f"Cannot convert { a_tp } and { b_tp } to a common type" )
171- return min (common , key = lambda tp : a_converts_to [tp ] + b_converts_to [tp ])
146+ # def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
147+ # """
148+ # Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
149+ # """
150+ # decls = _retrieve_conversion_decls().copy()
151+ # if isinstance(a, RuntimeExpr):
152+ # decls |= a
153+ # if isinstance(b, RuntimeExpr):
154+ # decls |= b
155+
156+ # a_tp = _get_tp(a)
157+ # b_tp = _get_tp(b)
158+ # # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
159+ # if not (
160+ # (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
161+ # or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
162+ # ):
163+ # raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
164+ # a_converts_to = {
165+ # to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
166+ # }
167+ # b_converts_to = {
168+ # to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
169+ # }
170+ # if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
171+ # a_converts_to[a_tp] = 0
172+ # if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
173+ # b_converts_to[b_tp] = 0
174+ # common = set(a_converts_to) & set(b_converts_to)
175+ # if not common:
176+ # raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
177+ # return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
172178
173179
174180def identity (x : object ) -> object :
@@ -197,7 +203,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio
197203def resolve_literal (
198204 tp : TypeOrVarRef ,
199205 arg : object ,
200- decls : Callable [[], Declarations ] = _retrieve_conversion_decls ,
206+ decls : Callable [[], Declarations ] = retrieve_conversion_decls ,
201207 tcs : TypeConstraintSolver | None = None ,
202208 cls_name : str | None = None ,
203209) -> RuntimeExpr :
@@ -208,12 +214,12 @@ def resolve_literal(
208214
209215 If it cannot be resolved, we assume that the value passed in will resolve it.
210216 """
211- arg_type = _get_tp (arg )
217+ arg_type = resolve_type (arg )
212218
213219 # If we have any type variables, dont bother trying to resolve the literal, just return the arg
214220 try :
215221 tp_just = tp .to_just ()
216- except NotImplementedError :
222+ except TypeVarError :
217223 # If this is a generic arg but passed in a non runtime expression, try to resolve the generic
218224 # args first based on the existing type constraint solver
219225 if tcs :
@@ -258,7 +264,7 @@ def _debug_print_converers():
258264 source_to_targets [source ].append (target )
259265
260266
261- def _get_tp (x : object ) -> JustTypeRef | type :
267+ def resolve_type (x : object ) -> JustTypeRef | type :
262268 if isinstance (x , RuntimeExpr ):
263269 return x .__egg_typed_expr__ .tp
264270 tp = type (x )
0 commit comments