Skip to content

Commit fd6e1ff

Browse files
Try fixing Python 3.10 compat
1 parent 617a350 commit fd6e1ff

File tree

7 files changed

+129
-28
lines changed

7 files changed

+129
-28
lines changed

docs/changelog.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ _This project uses semantic versioning_
44

55
## UNRELEASED
66

7-
- fix using `f64Like` when not importing star (also properly includes removal of `Callable` special case from previous release).
7+
- Fix using `f64Like` when not importing star (also properly includes removal of `Callable` special case from previous release).
8+
- Fix Python 3.10 compatibility
89

910
## 10.0.1 (2025-04-06)
1011

python/egglog/egraph.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
ClassVar,
1717
Generic,
1818
Literal,
19-
Never,
2019
TypeAlias,
2120
TypedDict,
2221
TypeVar,
@@ -26,7 +25,7 @@
2625
)
2726

2827
import graphviz
29-
from typing_extensions import ParamSpec, Self, Unpack, assert_never
28+
from typing_extensions import Never, ParamSpec, Self, Unpack, assert_never
3029

3130
from . import bindings
3231
from .conversion import *
@@ -36,6 +35,7 @@
3635
from .pretty import pretty_decl
3736
from .runtime import *
3837
from .thunk import *
38+
from .version_compat import *
3939

4040
if TYPE_CHECKING:
4141
from .builtins import String, Unit
@@ -169,8 +169,9 @@ def check_eq(x: BASE_EXPR, y: BASE_EXPR, schedule: Schedule | None = None, *, ad
169169
except bindings.EggSmolError as err:
170170
if display:
171171
egraph.display()
172-
err.add_note(f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})")
173-
raise
172+
raise add_note(
173+
f"Failed:\n{eq(x).to(y)}\n\nExtracted:\n {eq(egraph.extract(x)).to(egraph.extract(y))})", err
174+
) from None
174175
return egraph
175176

176177

@@ -492,8 +493,7 @@ def _generate_class_decls( # noqa: C901,PLR0912
492493
reverse_args=reverse_args,
493494
)
494495
except Exception as e:
495-
e.add_note(f"Error processing {cls_name}.{method_name}")
496-
raise
496+
raise add_note(f"Error processing {cls_name}.{method_name}", e) from None
497497

498498
if not builtin and not isinstance(ref, InitRef) and not mutates:
499499
add_default_funcs.append(add_rewrite)
@@ -1040,8 +1040,7 @@ def _run_extract(self, expr: RuntimeExpr, n: int) -> bindings._ExtractReport:
10401040
bindings.ActionCommand(bindings.Extract(span(2), expr, bindings.Lit(span(2), bindings.Int(n))))
10411041
)
10421042
except BaseException as e:
1043-
e.add_note("Extracting: " + str(expr))
1044-
raise
1043+
raise add_note("Extracting: " + str(expr), e) # noqa: B904
10451044
extract_report = self._egraph.extract_report()
10461045
if not extract_report:
10471046
msg = "No extract report saved"

python/egglog/exp/array_api.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969

7070
from egglog import *
7171
from egglog.runtime import RuntimeExpr
72+
from egglog.version_compat import add_note
7273

7374
from .program_gen import *
7475

@@ -1198,13 +1199,13 @@ def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ...
11981199

11991200
NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike
12001201

1201-
converter(NDArray, IndexKey, IndexKey.ndarray)
1202-
converter(Value, NDArray, NDArray.scalar)
1202+
converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v))
1203+
converter(Value, NDArray, lambda v: NDArray.scalar(v))
12031204
# Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
12041205
# to prefer upcasting in the other direction when we can, which is safer at runtime
12051206
converter(NDArray, Value, lambda n: n.to_value(), 100)
1206-
converter(TupleValue, NDArray, NDArray.vector)
1207-
converter(TupleInt, TupleValue, TupleValue.from_tuple_int)
1207+
converter(TupleValue, NDArray, lambda v: NDArray.vector(v))
1208+
converter(TupleInt, TupleValue, lambda v: TupleValue.from_tuple_int(v))
12081209

12091210

12101211
@array_api_ruleset.register
@@ -1383,8 +1384,8 @@ def int(cls, value: Int) -> IntOrTuple: ...
13831384
def tuple(cls, value: TupleIntLike) -> IntOrTuple: ...
13841385

13851386

1386-
converter(Int, IntOrTuple, IntOrTuple.int)
1387-
converter(TupleInt, IntOrTuple, IntOrTuple.tuple)
1387+
converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v))
1388+
converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v))
13881389

13891390

13901391
class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
@@ -1395,7 +1396,7 @@ def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ...
13951396

13961397

13971398
converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none)
1398-
converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)
1399+
converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v))
13991400

14001401

14011402
@function
@@ -1980,6 +1981,5 @@ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: Built
19801981
extracted = egraph.extract(prim_expr)
19811982
except BaseException as e:
19821983
# egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1983-
e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
1984-
raise
1984+
raise add_note(f"Cannot evaluate {egraph.extract(expr)}", e) # noqa: B904
19851985
return extracted.eval() # type: ignore[attr-defined]

python/egglog/runtime.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .pretty import *
2323
from .thunk import Thunk
2424
from .type_constraint_solver import *
25+
from .version_compat import *
2526

2627
if TYPE_CHECKING:
2728
from collections.abc import Iterable
@@ -249,8 +250,7 @@ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
249250
try:
250251
cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
251252
except Exception as e:
252-
e.add_note(f"Error processing class {self.__egg_tp__.name}")
253-
raise
253+
raise add_note(f"Error processing class {self.__egg_tp__.name}", e) from None
254254

255255
preserved_methods = cls_decl.preserved_methods
256256
if name in preserved_methods:
@@ -318,8 +318,7 @@ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs:
318318
try:
319319
signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).signature
320320
except Exception as e:
321-
e.add_note(f"Failed to find callable {self}")
322-
raise
321+
raise add_note(f"Failed to find callable {self}", e) # noqa: B904
323322
decls = self.__egg_decls__.copy()
324323
# Special case function application bc we dont support variadic generics yet generally
325324
if signature == "fn-app":

python/egglog/type_constraint_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def substitute_typevars(self, tp: TypeOrVarRef, cls_name: str | None = None) ->
107107
try:
108108
return self._cls_typevar_index_to_type[cls_name][tp]
109109
except KeyError as e:
110-
raise TypeConstraintError(f"Not enough bound typevars for {tp} in class {cls_name}") from e
110+
raise TypeConstraintError(f"Not enough bound typevars for {tp!r} in class {cls_name}") from e
111111
case TypeRefWithVars(name, args):
112112
return JustTypeRef(name, tuple(self.substitute_typevars(arg, cls_name) for arg in args))
113113
assert_never(tp)

python/egglog/version_compat.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import collections
2+
import sys
3+
import types
4+
import typing
5+
6+
BEFORE_3_11 = sys.version_info < (3, 11)
7+
8+
__all__ = ["add_note"]
9+
10+
11+
def add_note(message: str, exc: BaseException) -> BaseException:
12+
"""
13+
Backwards compatible add_note for Python <= 3.10
14+
"""
15+
if BEFORE_3_11:
16+
return exc
17+
exc.add_note(message)
18+
return exc
19+
20+
21+
# For Python version 3.10 need to monkeypatch this function so that RuntimeClass type parameters
22+
# will be collected as typevars
23+
if BEFORE_3_11:
24+
25+
@typing.no_type_check
26+
def _collect_type_vars_monkeypatch(types_, typevar_types=None):
27+
"""
28+
Collect all type variable contained
29+
in types in order of first appearance (lexicographic order). For example::
30+
31+
_collect_type_vars((T, List[S, T])) == (T, S)
32+
"""
33+
from .runtime import RuntimeClass
34+
35+
if typevar_types is None:
36+
typevar_types = typing.TypeVar
37+
tvars = []
38+
for t in types_:
39+
if isinstance(t, typevar_types) and t not in tvars:
40+
tvars.append(t)
41+
# **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
42+
if isinstance(t, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)): # type: ignore[name-defined]
43+
tvars.extend([t for t in t.__parameters__ if t not in tvars])
44+
return tuple(tvars)
45+
46+
typing._collect_type_vars = _collect_type_vars_monkeypatch # type: ignore[attr-defined]
47+
48+
@typing.no_type_check
49+
@typing._tp_cache
50+
def __getitem__monkeypatch(self, params): # noqa: C901, PLR0912
51+
from .runtime import RuntimeClass
52+
53+
if self.__origin__ in (typing.Generic, typing.Protocol):
54+
# Can't subscript Generic[...] or Protocol[...].
55+
raise TypeError(f"Cannot subscript already-subscripted {self}")
56+
if not isinstance(params, tuple):
57+
params = (params,)
58+
params = tuple(typing._type_convert(p) for p in params)
59+
if self._paramspec_tvars and any(isinstance(t, typing.ParamSpec) for t in self.__parameters__):
60+
params = typing._prepare_paramspec_params(self, params)
61+
else:
62+
typing._check_generic(self, params, len(self.__parameters__))
63+
64+
subst = dict(zip(self.__parameters__, params, strict=False))
65+
new_args = []
66+
for arg in self.__args__:
67+
if isinstance(arg, self._typevar_types):
68+
if isinstance(arg, typing.ParamSpec):
69+
arg = subst[arg] # noqa: PLW2901
70+
if not typing._is_param_expr(arg):
71+
raise TypeError(f"Expected a list of types, an ellipsis, ParamSpec, or Concatenate. Got {arg}")
72+
else:
73+
arg = subst[arg] # noqa: PLW2901
74+
# **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
75+
elif isinstance(arg, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)):
76+
subparams = arg.__parameters__
77+
if subparams:
78+
subargs = tuple(subst[x] for x in subparams)
79+
arg = arg[subargs] # noqa: PLW2901
80+
# Required to flatten out the args for CallableGenericAlias
81+
if self.__origin__ == collections.abc.Callable and isinstance(arg, tuple):
82+
new_args.extend(arg)
83+
else:
84+
new_args.append(arg)
85+
return self.copy_with(tuple(new_args))
86+
87+
typing._GenericAlias.__getitem__ = __getitem__monkeypatch # type: ignore[attr-defined]

python/tests/test_high_level.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pathlib
66
from copy import copy
77
from fractions import Fraction
8-
from typing import ClassVar, TypeAlias
8+
from typing import ClassVar, TypeAlias, TypeVar
99

1010
import pytest
1111

@@ -17,6 +17,7 @@
1717
MethodRef,
1818
TypedExprDecl,
1919
)
20+
from egglog.version_compat import BEFORE_3_11
2021

2122

2223
class TestExprStr:
@@ -793,10 +794,11 @@ class E(Expr):
793794
@function(cost=10)
794795
def __init__(self) -> None: ...
795796

796-
with pytest.raises(
797-
ValueError,
798-
match="Inside of classes, wrap methods with the `method` decorator, not `function`\nError processing E.__init__",
799-
):
797+
match = "Inside of classes, wrap methods with the `method` decorator, not `function`"
798+
# If we are after 3 11 we have context included
799+
if not BEFORE_3_11:
800+
match += "\nError processing E.__init__"
801+
with pytest.raises(ValueError, match=match):
800802
E()
801803

802804

@@ -858,6 +860,19 @@ def __eq__(self, other: B) -> B: ... # type: ignore[override]
858860
assert not isinstance(B() == B(), Fact)
859861

860862

863+
T = TypeVar("T")
864+
865+
866+
def test_type_param_sub():
867+
"""
868+
Verify that type substituion works properly, by comparing string version.
869+
870+
Comparing actual versions is always false if they are no the same object for unions
871+
"""
872+
V = Vec[T] | int
873+
assert str(V[Unit]) == str(Vec[Unit] | int) # type: ignore[misc]
874+
875+
861876
EXAMPLE_FILES = list((pathlib.Path(__file__).parent / "../egglog/examples").glob("*.py"))
862877

863878

0 commit comments

Comments
 (0)