Skip to content

Commit eb5f9b8

Browse files
Save snapshots in testing even if doesn't finish
1 parent 04da652 commit eb5f9b8

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

python/egglog/exp/array_api_jit.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,22 @@
1414
X = TypeVar("X", bound=Callable)
1515

1616

17-
def jit(fn: X) -> X:
17+
def jit(
18+
fn: X,
19+
*,
20+
handle_expr: Callable[[NDArray], None] | None = None,
21+
handle_optimized_expr: Callable[[NDArray], None] | None = None,
22+
) -> X:
1823
"""
1924
Jit compiles a function
2025
"""
2126
egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
27+
if handle_expr:
28+
handle_expr(res)
29+
if handle_optimized_expr:
30+
handle_optimized_expr(res_optimized)
2231
fn_program = EvalProgram(program, {"np": np})
23-
fn = cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
24-
fn.initial_expr = res # type: ignore[attr-defined]
25-
fn.expr = res_optimized # type: ignore[attr-defined]
26-
return fn
32+
return cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
2733

2834

2935
def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:

python/tests/test_array_api.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# mypy: disable-error-code="empty-body"
22
import inspect
33
from collections.abc import Callable
4+
from functools import partial
45
from itertools import product
56
from pathlib import Path
67
from types import FunctionType
@@ -352,9 +353,12 @@ def lda(X: NDArray, y: NDArray):
352353
],
353354
)
354355
def test_jit(program, snapshot_py, benchmark):
355-
jitted = benchmark(jit, program)
356-
assert str(jitted.initial_expr) == snapshot_py(name="initial_expr")
357-
assert str(jitted.expr) == snapshot_py(name="expr")
356+
def save_expr(name, expr):
357+
assert str(expr) == snapshot_py(name=name)
358+
359+
jitted = benchmark(
360+
jit, program, handle_expr=partial(save_expr, "initial_expr"), handle_optimized_expr=partial(save_expr, "expr")
361+
)
358362
assert inspect.getsource(jitted) == snapshot_py(name="code")
359363

360364

0 commit comments

Comments
 (0)