Skip to content

Commit 5f44c26

Browse files
Fix conversion to/from generic types
1 parent 96bf07f commit 5f44c26

18 files changed

+408
-178
lines changed

docs/changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ _This project uses semantic versioning_
1212
- Add working loopnest test and rewrite array api suport to be more general
1313
- Improve tracebacks on failing conversions.
1414
- Use `add_note` for exception to add more context, instead of raising a new exception, to make it easier to debug.
15-
- Add conversions from vecs (TODO)
15+
- Add conversions from generic types to be supported at runtime and typing level (so can go from `(1, 2, 3)` to `TupleInt`)
1616
- Open files with webbrowser instead of internal graphviz util for better support
1717

1818
## 8.0.1 (2024-10-24)

python/egglog/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from . import config, ipython_magic # noqa: F401
66
from .builtins import * # noqa: UP029
7-
from .conversion import convert, converter # noqa: F401
7+
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
88
from .egraph import *
99

1010
del ipython_magic

python/egglog/builtins.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from __future__ import annotations
77

8-
from functools import partial
8+
from functools import partial, reduce
99
from types import FunctionType
1010
from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, cast, overload
1111

@@ -25,13 +25,16 @@
2525
"Bool",
2626
"BoolLike",
2727
"Map",
28+
"MapLike",
2829
"PyObject",
2930
"Rational",
3031
"Set",
32+
"SetLike",
3133
"String",
3234
"StringLike",
3335
"UnstableFn",
3436
"Vec",
37+
"VecLike",
3538
"f64",
3639
"f64Like",
3740
"i64",
@@ -285,6 +288,22 @@ def remove(self, key: T) -> Map[T, V]: ...
285288
def rebuild(self) -> Map[T, V]: ...
286289

287290

291+
TO = TypeVar("TO")
292+
VO = TypeVar("VO")
293+
294+
converter(
295+
dict,
296+
Map,
297+
lambda t: reduce(
298+
(lambda acc, kv: acc.insert(convert(kv[0], get_type_args()[0]), convert(kv[1], get_type_args()[1]))),
299+
t.items(),
300+
Map[get_type_args()].empty(), # type: ignore[misc]
301+
),
302+
)
303+
304+
MapLike: TypeAlias = Map[T, V] | dict[TO, VO]
305+
306+
288307
class Set(Expr, Generic[T], builtin=True):
289308
@method(egg_fn="set-of")
290309
def __init__(self, *args: T) -> None: ...
@@ -318,6 +337,17 @@ def __and__(self, other: Set[T]) -> Set[T]: ...
318337
def rebuild(self) -> Set[T]: ...
319338

320339

340+
converter(
341+
set,
342+
Set,
343+
lambda t: Set[get_type_args()[0]]( # type: ignore[misc,operator]
344+
*(convert(x, get_type_args()[0]) for x in t)
345+
),
346+
)
347+
348+
SetLike: TypeAlias = Set[T] | set[TO]
349+
350+
321351
class Rational(Expr, builtin=True):
322352
@method(egg_fn="rational")
323353
def __init__(self, num: i64Like, den: i64Like) -> None: ...
@@ -418,7 +448,16 @@ def remove(self, index: i64Like) -> Vec[T]: ...
418448
def set(self, index: i64Like, value: T) -> Vec[T]: ...
419449

420450

421-
converter(tuple, Vec, lambda t: Vec(*(convert(x, get_type_args()[0]) for x in t)))
451+
for sequence_type in (list, tuple):
452+
converter(
453+
sequence_type,
454+
Vec,
455+
lambda t: Vec[get_type_args()[0]]( # type: ignore[misc,operator]
456+
*(convert(x, get_type_args()[0]) for x in t)
457+
),
458+
)
459+
460+
VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO]
422461

423462

424463
class PyObject(Expr, builtin=True):

python/egglog/conversion.py

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3+
from collections import defaultdict
34
from contextlib import contextmanager
45
from contextvars import ContextVar
56
from dataclasses import dataclass
6-
from typing import TYPE_CHECKING, NewType, TypeVar, cast
7+
from typing import TYPE_CHECKING, TypeVar, cast
78

89
from .declarations import *
910
from .pretty import *
@@ -15,10 +16,9 @@
1516

1617
from .egraph import Expr
1718

18-
__all__ = ["convert", "convert_to_same_type", "converter", "resolve_literal"]
19+
__all__ = ["convert", "convert_to_same_type", "converter", "resolve_literal", "ConvertError"]
1920
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
20-
TypeName = NewType("TypeName", str)
21-
CONVERSIONS: dict[tuple[type | TypeName, TypeName], tuple[int, Callable]] = {}
21+
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
2222
# Global declerations to store all convertable types so we can query if they have certain methods or not
2323
_CONVERSION_DECLS = Declarations.create()
2424
# Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing
@@ -45,12 +45,12 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
4545
Register a converter from some type to an egglog type.
4646
"""
4747
to_type_name = process_tp(to_type)
48-
if not isinstance(to_type_name, str):
48+
if not isinstance(to_type_name, JustTypeRef):
4949
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
5050
_register_converter(process_tp(from_type), to_type_name, fn, cost)
5151

5252

53-
def _register_converter(a: type | TypeName, b: TypeName, a_b: Callable, cost: int) -> None:
53+
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
5454
"""
5555
Registers a converter from some type to an egglog type, if not already registered.
5656
@@ -63,10 +63,26 @@ def _register_converter(a: type | TypeName, b: TypeName, a_b: Callable, cost: in
6363
return
6464
CONVERSIONS[(a, b)] = (cost, a_b)
6565
for (c, d), (other_cost, c_d) in list(CONVERSIONS.items()):
66-
if b == c:
67-
_register_converter(a, d, _ComposedConverter(a_b, c_d), cost + other_cost)
68-
if a == d:
69-
_register_converter(c, b, _ComposedConverter(c_d, a_b), cost + other_cost)
66+
if _is_type_compatible(b, c):
67+
_register_converter(
68+
a, d, _ComposedConverter(a_b, c_d, c.args if isinstance(c, JustTypeRef) else ()), cost + other_cost
69+
)
70+
if _is_type_compatible(a, d):
71+
_register_converter(
72+
c, b, _ComposedConverter(c_d, a_b, a.args if isinstance(a, JustTypeRef) else ()), cost + other_cost
73+
)
74+
75+
76+
def _is_type_compatible(source: type | JustTypeRef, target: type | JustTypeRef) -> bool:
77+
"""
78+
Types must be equal or also support unbound to bound typevar like B -> B[C]
79+
"""
80+
if source == target:
81+
return True
82+
if isinstance(source, JustTypeRef) and isinstance(target, JustTypeRef) and source.args and not target.args:
83+
return source.name == target.name
84+
# TODO: Support case where B[T] where T is typevar is mapped to B[C]
85+
return False
7086

7187

7288
@dataclass
@@ -81,9 +97,17 @@ class _ComposedConverter:
8197

8298
a_b: Callable
8399
b_c: Callable
100+
b_args: tuple[JustTypeRef, ...]
84101

85102
def __call__(self, x: object) -> object:
86-
return self.b_c(self.a_b(x))
103+
# if we have A -> B and B[C] -> D then we should use (C,) as the type args
104+
# when converting from A -> B
105+
if self.b_args:
106+
with with_type_args(self.b_args, _retrieve_conversion_decls):
107+
first_res = self.a_b(x)
108+
else:
109+
first_res = self.a_b(x)
110+
return self.b_c(first_res)
87111

88112
def __str__(self) -> str:
89113
return f"{self.b_c}{self.a_b}"
@@ -105,35 +129,33 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
105129
return resolve_literal(tp.to_var(), source, Thunk.value(target.__egg_decls__))
106130

107131

108-
def process_tp(tp: type | RuntimeClass) -> TypeName | type:
132+
def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
109133
"""
110134
Process a type before converting it, to add it to the global declerations and resolve to a ref.
111135
"""
112136
if isinstance(tp, RuntimeClass):
113137
_TO_PROCESS_DECLS.append(tp)
114138
egg_tp = tp.__egg_tp__
115-
if egg_tp.args:
116-
raise TypeError(f"Cannot register a converter for a generic type, got {tp}")
117-
return TypeName(egg_tp.name)
139+
return egg_tp.to_just()
118140
return tp
119141

120142

121-
def min_convertable_tp(a: object, b: object, name: str) -> TypeName:
143+
def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
122144
"""
123145
Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
124146
"""
125147
decls = _retrieve_conversion_decls()
126148
a_tp = _get_tp(a)
127149
b_tp = _get_tp(b)
128150
a_converts_to = {
129-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to, name)
151+
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
130152
}
131153
b_converts_to = {
132-
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to, name)
154+
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
133155
}
134-
if isinstance(a_tp, str):
156+
if isinstance(a_tp, JustTypeRef):
135157
a_converts_to[a_tp] = 0
136-
if isinstance(b_tp, str):
158+
if isinstance(b_tp, JustTypeRef):
137159
b_converts_to[b_tp] = 0
138160
common = set(a_converts_to) & set(b_converts_to)
139161
if not common:
@@ -176,27 +198,38 @@ def resolve_literal(
176198
# If this is a var, it has to be a runtime expession
177199
assert isinstance(arg, RuntimeExpr), f"Expected a runtime expression, got {arg}"
178200
return arg
179-
tp_name = TypeName(tp_just.name)
180-
if arg_type == tp_name:
201+
if arg_type == tp_just:
181202
# If the type is an egg type, it has to be a runtime expr
182203
assert isinstance(arg, RuntimeExpr)
183204
return arg
184205
# Try all parent types as well, if we are converting from a Python type
185206
for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
186-
try:
187-
fn = CONVERSIONS[(arg_type_instance, tp_name)][1]
188-
except KeyError:
189-
continue
190-
break
207+
if (key := (arg_type_instance, tp_just)) in CONVERSIONS:
208+
fn = CONVERSIONS[key][1]
209+
break
210+
# Try broadening if we have a convert to the general type instead of the specific one too, for generics
211+
if tp_just.args and (key := (arg_type_instance, JustTypeRef(tp_just.name))) in CONVERSIONS:
212+
fn = CONVERSIONS[key][1]
213+
break
214+
# if we didn't find any raise an error
191215
else:
192-
raise ConvertError(f"Cannot convert {arg_type} to {tp_name}")
216+
raise ConvertError(f"Cannot convert {arg_type} to {tp_just}")
193217
with with_type_args(tp_just.args, decls):
194218
return fn(arg)
195219

196220

197-
def _get_tp(x: object) -> TypeName | type:
221+
def _debug_print_converers():
222+
"""
223+
Prints a mapping of all source types to target types that have a conversion function.
224+
"""
225+
source_to_targets = defaultdict(list)
226+
for source, target in CONVERSIONS:
227+
source_to_targets[source].append(target)
228+
229+
230+
def _get_tp(x: object) -> JustTypeRef | type:
198231
if isinstance(x, RuntimeExpr):
199-
return TypeName(x.__egg_typed_expr__.tp.name)
232+
return x.__egg_typed_expr__.tp
200233
tp = type(x)
201234
# If this value has a custom metaclass, let's use that as our index instead of the type
202235
if type(tp) is not type:

0 commit comments

Comments
 (0)