Skip to content

Commit 617a350

Browse files
Add error context to thunk
1 parent c1b1771 commit 617a350

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

python/egglog/egraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ def _fn_decl(
671671
)
672672
res_ref = ref
673673
decls.set_function_decl(ref, decl)
674-
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset)
674+
res_thunk = Thunk.fn(_create_default_value, decls, ref, fn, args, ruleset, context=f"creating {ref}")
675675
return res_ref, Thunk.fn(_add_default_rewrite_function, decls, res_ref, return_type, ruleset, res_thunk, subsume)
676676

677677

python/egglog/thunk.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ class Thunk(Generic[T, Unpack[TS]]):
4141
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
4242

4343
@classmethod
44-
def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS]) -> Thunk[T, Unpack[TS]]:
44+
def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], context: str | None = None) -> Thunk[T, Unpack[TS]]:
4545
"""
4646
Create a thunk based on some functions and some partial args.
4747
4848
If the function is called while it is being resolved recursively it will raise an exception.
4949
"""
50-
return cls(Unresolved(fn, args))
50+
return cls(Unresolved(fn, args, context))
5151

5252
@classmethod
5353
def value(cls, value: T) -> Thunk[T]:
@@ -57,12 +57,12 @@ def __call__(self) -> T:
5757
match self.state:
5858
case Resolved(value):
5959
return value
60-
case Unresolved(fn, args):
60+
case Unresolved(fn, args, context):
6161
self.state = Resolving()
6262
try:
6363
res = fn(*args)
6464
except Exception as e:
65-
self.state = Error(e)
65+
self.state = Error(e, context)
6666
raise e from None
6767
else:
6868
self.state = Resolved(res)
@@ -83,6 +83,7 @@ class Resolved(Generic[T]):
8383
class Unresolved(Generic[T, Unpack[TS]]):
8484
fn: Callable[[Unpack[TS]], T]
8585
args: tuple[Unpack[TS]]
86+
context: str | None
8687

8788

8889
@dataclass
@@ -93,3 +94,4 @@ class Resolving:
9394
@dataclass
9495
class Error:
9596
e: Exception
97+
context: str | None

0 commit comments

Comments
 (0)