6969
7070from egglog import *
7171from egglog .runtime import RuntimeExpr
72+ from egglog .version_compat import add_note
7273
7374from .program_gen import *
7475
@@ -1198,13 +1199,13 @@ def if_(cls, b: BooleanLike, i: NDArrayLike, j: NDArrayLike) -> NDArray: ...
11981199
11991200NDArrayLike : TypeAlias = NDArray | ValueLike | TupleValueLike
12001201
1201- converter (NDArray , IndexKey , IndexKey .ndarray )
1202- converter (Value , NDArray , NDArray .scalar )
1202+ converter (NDArray , IndexKey , lambda v : IndexKey .ndarray ( v ) )
1203+ converter (Value , NDArray , lambda v : NDArray .scalar ( v ) )
12031204# Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
12041205# to prefer upcasting in the other direction when we can, which is safer at runtime
12051206converter (NDArray , Value , lambda n : n .to_value (), 100 )
1206- converter (TupleValue , NDArray , NDArray .vector )
1207- converter (TupleInt , TupleValue , TupleValue .from_tuple_int )
1207+ converter (TupleValue , NDArray , lambda v : NDArray .vector ( v ) )
1208+ converter (TupleInt , TupleValue , lambda v : TupleValue .from_tuple_int ( v ) )
12081209
12091210
12101211@array_api_ruleset .register
@@ -1383,8 +1384,8 @@ def int(cls, value: Int) -> IntOrTuple: ...
13831384 def tuple (cls , value : TupleIntLike ) -> IntOrTuple : ...
13841385
13851386
1386- converter (Int , IntOrTuple , IntOrTuple .int )
1387- converter (TupleInt , IntOrTuple , IntOrTuple .tuple )
1387+ converter (Int , IntOrTuple , lambda v : IntOrTuple .int ( v ) )
1388+ converter (TupleInt , IntOrTuple , lambda v : IntOrTuple .tuple ( v ) )
13881389
13891390
13901391class OptionalIntOrTuple (Expr , ruleset = array_api_ruleset ):
@@ -1395,7 +1396,7 @@ def some(cls, value: IntOrTuple) -> OptionalIntOrTuple: ...
13951396
13961397
13971398converter (type (None ), OptionalIntOrTuple , lambda _ : OptionalIntOrTuple .none )
1398- converter (IntOrTuple , OptionalIntOrTuple , OptionalIntOrTuple .some )
1399+ converter (IntOrTuple , OptionalIntOrTuple , lambda v : OptionalIntOrTuple .some ( v ) )
13991400
14001401
14011402@function
@@ -1980,6 +1981,5 @@ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: Built
19801981 extracted = egraph .extract (prim_expr )
19811982 except BaseException as e :
19821983 # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1983- e .add_note (f"Cannot evaluate { egraph .extract (expr )} " )
1984- raise
1984+ raise add_note (f"Cannot evaluate { egraph .extract (expr )} " , e ) # noqa: B904
19851985 return extracted .eval () # type: ignore[attr-defined]
0 commit comments