1
1
from __future__ import annotations
2
2
3
3
from collections import defaultdict
4
+ from collections .abc import Callable
4
5
from contextlib import contextmanager
5
6
from contextvars import ContextVar
6
7
from dataclasses import dataclass
7
- from typing import TYPE_CHECKING , TypeVar , cast
8
+ from typing import TYPE_CHECKING , Any , TypeVar , cast
8
9
9
10
from .declarations import *
10
11
from .pretty import *
13
14
from .type_constraint_solver import TypeConstraintError
14
15
15
16
if TYPE_CHECKING :
16
- from collections .abc import Callable , Generator
17
+ from collections .abc import Generator
17
18
18
19
from .egraph import BaseExpr
19
20
from .type_constraint_solver import TypeConstraintSolver
20
21
21
- __all__ = ["ConvertError" , "convert" , "convert_to_same_type" , " converter" , "resolve_literal " ]
22
+ __all__ = ["ConvertError" , "convert" , "converter" , "get_type_args " ]
22
23
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
23
- CONVERSIONS : dict [tuple [type | JustTypeRef , JustTypeRef ], tuple [int , Callable ]] = {}
24
+ CONVERSIONS : dict [tuple [type | JustTypeRef , JustTypeRef ], tuple [int , Callable [[ Any ], RuntimeExpr ] ]] = {}
24
25
# Global declerations to store all convertable types so we can query if they have certain methods or not
25
26
_CONVERSION_DECLS = Declarations .create ()
26
27
# Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing
27
28
# until we need them
28
29
_TO_PROCESS_DECLS : list [DeclerationsLike ] = []
29
30
30
31
31
- def _retrieve_conversion_decls () -> Declarations :
32
+ def retrieve_conversion_decls () -> Declarations :
32
33
_CONVERSION_DECLS .update (* _TO_PROCESS_DECLS )
33
34
_TO_PROCESS_DECLS .clear ()
34
35
return _CONVERSION_DECLS
@@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
49
50
to_type_name = process_tp (to_type )
50
51
if not isinstance (to_type_name , JustTypeRef ):
51
52
raise TypeError (f"Expected return type to be a egglog type, got { to_type_name } " )
52
- _register_converter (process_tp (from_type ), to_type_name , fn , cost )
53
+ _register_converter (process_tp (from_type ), to_type_name , cast ( "Callable[[Any], RuntimeExpr]" , fn ) , cost )
53
54
54
55
55
- def _register_converter (a : type | JustTypeRef , b : JustTypeRef , a_b : Callable , cost : int ) -> None :
56
+ def _register_converter (a : type | JustTypeRef , b : JustTypeRef , a_b : Callable [[ Any ], RuntimeExpr ] , cost : int ) -> None :
56
57
"""
57
58
Registers a converter from some type to an egglog type, if not already registered.
58
59
@@ -97,15 +98,15 @@ class _ComposedConverter:
97
98
We use the dataclass instead of the lambda to make it easier to debug.
98
99
"""
99
100
100
- a_b : Callable
101
- b_c : Callable
101
+ a_b : Callable [[ Any ], RuntimeExpr ]
102
+ b_c : Callable [[ Any ], RuntimeExpr ]
102
103
b_args : tuple [JustTypeRef , ...]
103
104
104
- def __call__ (self , x : object ) -> object :
105
+ def __call__ (self , x : Any ) -> RuntimeExpr :
105
106
# if we have A -> B and B[C] -> D then we should use (C,) as the type args
106
107
# when converting from A -> B
107
108
if self .b_args :
108
- with with_type_args (self .b_args , _retrieve_conversion_decls ):
109
+ with with_type_args (self .b_args , retrieve_conversion_decls ):
109
110
first_res = self .a_b (x )
110
111
else :
111
112
first_res = self .a_b (x )
@@ -142,33 +143,38 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
142
143
return tp
143
144
144
145
145
- def min_convertable_tp (a : object , b : object , name : str ) -> JustTypeRef :
146
- """
147
- Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
148
- """
149
- decls = _retrieve_conversion_decls ()
150
- a_tp = _get_tp (a )
151
- b_tp = _get_tp (b )
152
- # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153
- if not (
154
- (isinstance (a_tp , JustTypeRef ) and decls .has_method (a_tp .name , name ))
155
- or (isinstance (b_tp , JustTypeRef ) and decls .has_method (b_tp .name , name ))
156
- ):
157
- raise ConvertError (f"Neither { a_tp } nor { b_tp } has method { name } " )
158
- a_converts_to = {
159
- to : c for ((from_ , to ), (c , _ )) in CONVERSIONS .items () if from_ == a_tp and decls .has_method (to .name , name )
160
- }
161
- b_converts_to = {
162
- to : c for ((from_ , to ), (c , _ )) in CONVERSIONS .items () if from_ == b_tp and decls .has_method (to .name , name )
163
- }
164
- if isinstance (a_tp , JustTypeRef ) and decls .has_method (a_tp .name , name ):
165
- a_converts_to [a_tp ] = 0
166
- if isinstance (b_tp , JustTypeRef ) and decls .has_method (b_tp .name , name ):
167
- b_converts_to [b_tp ] = 0
168
- common = set (a_converts_to ) & set (b_converts_to )
169
- if not common :
170
- raise ConvertError (f"Cannot convert { a_tp } and { b_tp } to a common type" )
171
- return min (common , key = lambda tp : a_converts_to [tp ] + b_converts_to [tp ])
146
+ # def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
147
+ # """
148
+ # Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
149
+ # """
150
+ # decls = _retrieve_conversion_decls().copy()
151
+ # if isinstance(a, RuntimeExpr):
152
+ # decls |= a
153
+ # if isinstance(b, RuntimeExpr):
154
+ # decls |= b
155
+
156
+ # a_tp = _get_tp(a)
157
+ # b_tp = _get_tp(b)
158
+ # # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
159
+ # if not (
160
+ # (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
161
+ # or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
162
+ # ):
163
+ # raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
164
+ # a_converts_to = {
165
+ # to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
166
+ # }
167
+ # b_converts_to = {
168
+ # to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
169
+ # }
170
+ # if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
171
+ # a_converts_to[a_tp] = 0
172
+ # if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
173
+ # b_converts_to[b_tp] = 0
174
+ # common = set(a_converts_to) & set(b_converts_to)
175
+ # if not common:
176
+ # raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
177
+ # return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
172
178
173
179
174
180
def identity (x : object ) -> object :
@@ -197,7 +203,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio
197
203
def resolve_literal (
198
204
tp : TypeOrVarRef ,
199
205
arg : object ,
200
- decls : Callable [[], Declarations ] = _retrieve_conversion_decls ,
206
+ decls : Callable [[], Declarations ] = retrieve_conversion_decls ,
201
207
tcs : TypeConstraintSolver | None = None ,
202
208
cls_name : str | None = None ,
203
209
) -> RuntimeExpr :
@@ -208,12 +214,12 @@ def resolve_literal(
208
214
209
215
If it cannot be resolved, we assume that the value passed in will resolve it.
210
216
"""
211
- arg_type = _get_tp (arg )
217
+ arg_type = resolve_type (arg )
212
218
213
219
# If we have any type variables, dont bother trying to resolve the literal, just return the arg
214
220
try :
215
221
tp_just = tp .to_just ()
216
- except NotImplementedError :
222
+ except TypeVarError :
217
223
# If this is a generic arg but passed in a non runtime expression, try to resolve the generic
218
224
# args first based on the existing type constraint solver
219
225
if tcs :
@@ -258,7 +264,7 @@ def _debug_print_converers():
258
264
source_to_targets [source ].append (target )
259
265
260
266
261
- def _get_tp (x : object ) -> JustTypeRef | type :
267
+ def resolve_type (x : object ) -> JustTypeRef | type :
262
268
if isinstance (x , RuntimeExpr ):
263
269
return x .__egg_typed_expr__ .tp
264
270
tp = type (x )
0 commit comments