Skip to content

Commit 9009e8a

Browse files
authored
✅ test: add runtime checker (#3)
1 parent a085939 commit 9009e8a

File tree

10 files changed

+154
-1
lines changed

10 files changed

+154
-1
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,4 @@ print(ast.unparse(simplified_tree))
5252

5353
## TODOs
5454

55-
- [ ] Add runtime checks in uts
55+
- [ ] Automatically generate test cases

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ combine-as-imports = true
8080
[tool.ruff.lint.per-file-ignores]
8181
"setup.py" = ["I"]
8282

83+
[tool.pytest.ini_options]
84+
python_files = ["tests/*.py", "tests/**/*.py"]
85+
8386
[build-system]
8487
requires = ["hatchling"]
8588
build-backend = "hatchling.build"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from __future__ import annotations
2+
3+
from expr_simplifier.analyzer.extern_symbol_analyzer import analyze_external_symbols as analyze_external_symbols
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
5+
from expr_simplifier.symbol_table import SymbolTable
6+
7+
8+
class ExternalSymbolAnalyzer(ast.NodeVisitor):
9+
def __init__(self):
10+
self.symbol_table = SymbolTable()
11+
self.external_symbols: list[str] = []
12+
13+
def visit_Name(self, node: ast.Name) -> None:
14+
if isinstance(node.ctx, ast.Store):
15+
self.symbol_table.define_symbol(node.id)
16+
elif isinstance(node.ctx, ast.Load):
17+
if not self.symbol_table.is_symbol_defined(node.id) and node.id not in self.external_symbols:
18+
self.external_symbols.append(node.id)
19+
20+
21+
def analyze_external_symbols(tree: ast.AST) -> list[str]:
22+
analyzer = ExternalSymbolAnalyzer()
23+
analyzer.visit(tree)
24+
return analyzer.external_symbols

tests/test_transforms/test_constant_folding.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from expr_simplifier.transforms import apply_constant_folding
88
from expr_simplifier.utils import loop_until_stable
99

10+
from .utils import check_expr_at_runtime
11+
1012

1113
@pytest.mark.parametrize(
1214
["expr", "expected"],
@@ -27,6 +29,7 @@ def test_constant_folding(expr: str, expected: str):
2729
transformed_tree = apply_constant_folding(tree)
2830
transformed_expr = ast.unparse(transformed_tree)
2931
assert transformed_expr == expected
32+
check_expr_at_runtime(tree, transformed_tree)
3033

3134

3235
@pytest.mark.parametrize(
@@ -41,3 +44,4 @@ def test_constant_folding_loop_until_stable(expr: str, expected: str, max_iter:
4144
transformed_tree = loop_until_stable(tree, [apply_constant_folding], max_iter)
4245
transformed_expr = ast.unparse(transformed_tree)
4346
assert transformed_expr == expected
47+
check_expr_at_runtime(tree, transformed_tree)

tests/test_transforms/test_cse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from expr_simplifier.transforms import apply_cse
88

9+
from .utils import check_expr_at_runtime
10+
911

1012
@pytest.mark.parametrize(
1113
["expr", "expected"],
@@ -36,3 +38,4 @@ def test_cse(expr: str, expected: str):
3638
transformed_tree = apply_cse(tree)
3739
transformed_expr = ast.unparse(transformed_tree)
3840
assert transformed_expr == expected
41+
check_expr_at_runtime(tree, transformed_tree)

tests/test_transforms/test_inline_named_expr.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
apply_inline_all_named_expr,
1010
)
1111

12+
from .utils import check_expr_at_runtime
13+
1214

1315
@pytest.mark.parametrize(
1416
["expr", "expected"],
@@ -25,6 +27,7 @@ def test_inline_named_expr(expr: str, expected: str):
2527
transformed_tree = apply_inline_all_named_expr(tree)
2628
transformed_expr = ast.unparse(transformed_tree)
2729
assert transformed_expr == expected
30+
check_expr_at_runtime(tree, transformed_tree)
2831

2932

3033
@pytest.mark.parametrize(
@@ -43,3 +46,4 @@ def test_constant_propagation(expr: str, expected: str):
4346
transformed_tree = apply_constant_propagation(tree)
4447
transformed_expr = ast.unparse(transformed_tree)
4548
assert transformed_expr == expected
49+
check_expr_at_runtime(tree, transformed_tree)

tests/test_transforms/test_logical_simplification.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
apply_remove_same_subexpression_in_logical_op,
1111
)
1212

13+
from .utils import check_expr_at_runtime
14+
1315

1416
@pytest.mark.parametrize(
1517
["expr", "expected"],
@@ -35,6 +37,7 @@ def test_logical_short_circuiting(expr: str, expected: str):
3537
transformed_tree = apply_logical_short_circuiting(tree)
3638
transformed_expr = ast.unparse(transformed_tree)
3739
assert transformed_expr == expected
40+
check_expr_at_runtime(tree, transformed_tree)
3841

3942

4043
@pytest.mark.parametrize(
@@ -60,6 +63,7 @@ def test_remove_same_subexpression_in_logical_op(expr: str, expected: str):
6063
transformed_tree = apply_remove_same_subexpression_in_logical_op(tree)
6164
transformed_expr = ast.unparse(transformed_tree)
6265
assert transformed_expr == expected
66+
check_expr_at_runtime(tree, transformed_tree)
6367

6468

6569
@pytest.mark.parametrize(
@@ -74,3 +78,4 @@ def test_logical_simplification(expr: str, expected: str):
7478
transformed_tree = apply_logical_simplification(tree)
7579
transformed_expr = ast.unparse(transformed_tree)
7680
assert transformed_expr == expected
81+
check_expr_at_runtime(tree, transformed_tree)

tests/test_transforms/test_remove_unused_named_expr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from expr_simplifier.transforms import apply_remove_unused_named_expr
88

9+
from .utils import check_expr_at_runtime
10+
911

1012
@pytest.mark.parametrize(
1113
["expr", "expected"],
@@ -20,3 +22,4 @@ def test_inline_named_expr(expr: str, expected: str):
2022
transformed_tree = apply_remove_unused_named_expr(tree)
2123
transformed_expr = ast.unparse(transformed_tree)
2224
assert transformed_expr == expected
25+
check_expr_at_runtime(tree, transformed_tree)

tests/test_transforms/utils.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
from typing import Any
5+
6+
from expr_simplifier.analyzer import analyze_external_symbols
7+
8+
9+
class AnyObject:
10+
"""A top type that can represent any object in Python"""
11+
12+
def __init__(self, name: str):
13+
self.name = name
14+
15+
def __add__(self, other: Any):
16+
other = AnyObject.to_any_object(other)
17+
return AnyObject(f"{self.name} + {other.name}")
18+
19+
def __radd__(self, other: Any):
20+
other = AnyObject.to_any_object(other)
21+
return AnyObject(f"{other.name} + {self.name}")
22+
23+
def __mul__(self, other: Any):
24+
other = AnyObject.to_any_object(other)
25+
return AnyObject(f"{self.name} * {other.name}")
26+
27+
def __rmul__(self, other: Any):
28+
other = AnyObject.to_any_object(other)
29+
return AnyObject(f"{other.name} * {self.name}")
30+
31+
def __truediv__(self, other: Any):
32+
other = AnyObject.to_any_object(other)
33+
return AnyObject(f"{self.name} / {other.name}")
34+
35+
def __rtruediv__(self, other: Any):
36+
other = AnyObject.to_any_object(other)
37+
return AnyObject(f"{other.name} / {self.name}")
38+
39+
def __floordiv__(self, other: Any):
40+
other = AnyObject.to_any_object(other)
41+
return AnyObject(f"{self.name} // {other.name}")
42+
43+
def __rfloordiv__(self, other: Any):
44+
other = AnyObject.to_any_object(other)
45+
return AnyObject(f"{other.name} // {self.name}")
46+
47+
def __getattr__(self, name: str):
48+
return AnyObject(f"{self.name}.{name}")
49+
50+
def __getitem__(self, key: Any):
51+
key = AnyObject.to_any_object(key)
52+
return AnyObject(f"{self.name}[{key.name}]")
53+
54+
def __call__(self, *args: Any, **kwargs: Any):
55+
formatted_args = ", ".join(AnyObject.to_any_object(arg).name for arg in args)
56+
formatted_kwargs = ", ".join(f"{key}={AnyObject.to_any_object(value).name}" for key, value in kwargs.items())
57+
return AnyObject(f"{self.name}({formatted_args}, {formatted_kwargs})")
58+
59+
def __bool__(self):
60+
return hash(self.name) % 2 == 0
61+
62+
def __hash__(self):
63+
return hash(self.name)
64+
65+
# def __eq__(self, other: object):
66+
# other = AnyObject.to_any_object(other)
67+
# return AnyObject(f"{self.name} == {other.name}")
68+
69+
# def __ne__(self, other: object):
70+
# other = AnyObject.to_any_object(other)
71+
# return AnyObject(f"{self.name} != {other.name}")
72+
73+
def __lt__(self, other: Any):
74+
other = AnyObject.to_any_object(other)
75+
return AnyObject(f"{self.name} < {other.name}")
76+
77+
def __le__(self, other: Any):
78+
other = AnyObject.to_any_object(other)
79+
return AnyObject(f"{self.name} <= {other.name}")
80+
81+
def __repr__(self):
82+
return f"AnyObject({self.name})"
83+
84+
@staticmethod
85+
def to_any_object(value: Any):
86+
if isinstance(value, AnyObject):
87+
return value
88+
return AnyObject.from_raw(value)
89+
90+
@staticmethod
91+
def from_raw(value: Any):
92+
return AnyObject(f"$R({value!r})")
93+
94+
95+
def check_expr_at_runtime(original_ast: ast.AST, transformed_ast: ast.AST) -> None:
96+
original_externel_symbols = analyze_external_symbols(original_ast)
97+
transformed_externel_symbols = analyze_external_symbols(transformed_ast)
98+
assert original_externel_symbols == transformed_externel_symbols, "External symbols are different"
99+
original_expr = ast.unparse(original_ast)
100+
transformed_expr = ast.unparse(transformed_ast)
101+
externel_fake_values = {symbol: AnyObject(symbol) for symbol in original_externel_symbols}
102+
original_expr_result = eval(original_expr, externel_fake_values)
103+
transformed_expr_result = eval(transformed_expr, externel_fake_values)
104+
assert repr(original_expr_result) == repr(transformed_expr_result), "Results are different"

0 commit comments

Comments
 (0)