Skip to content

Commit d40508c

Browse files
Merge branch 'main' into wma/scheduler
2 parents 77760f2 + cc04d92 commit d40508c

File tree

5 files changed

+151
-18
lines changed

5 files changed

+151
-18
lines changed

src/finchlite/symbolic/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
)
1212
from .stage import Stage
1313
from .term import (
14-
PostOrderDFS,
15-
PreOrderDFS,
1614
Term,
1715
TermTree,
1816
literal_repr,
1917
)
18+
from .traversal import PostOrderDFS, PreOrderDFS, intree, isdescendant
2019

2120
__all__ = [
2221
"BasicBlock",
@@ -42,5 +41,7 @@
4241
"fisinstance",
4342
"ftype",
4443
"gensym",
44+
"intree",
45+
"isdescendant",
4546
"literal_repr",
4647
]

src/finchlite/symbolic/environment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
from collections import defaultdict
44
from typing import Any, Generic, Optional, TypeVar
55

6-
from .term import PostOrderDFS, Term
6+
from .term import Term
7+
from .traversal import PostOrderDFS
78

89

910
class NamedTerm(Term, ABC):

src/finchlite/symbolic/term.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from collections.abc import Iterator
54
from dataclasses import dataclass
65
from inspect import isbuiltin, isclass, isfunction
76
from typing import Any, Self
@@ -82,17 +81,3 @@ def literal_repr(name: str, fields: dict[str, Any]) -> str:
8281
return (
8382
name + "(" + ", ".join([f"{k}={_get_repr(v)}" for k, v in fields.items()]) + ")"
8483
)
85-
86-
87-
def PostOrderDFS(node: Term) -> Iterator[Term]:
88-
if isinstance(node, TermTree):
89-
for arg in node.children:
90-
yield from PostOrderDFS(arg)
91-
yield node
92-
93-
94-
def PreOrderDFS(node: Term) -> Iterator[Term]:
95-
yield node
96-
if isinstance(node, TermTree):
97-
for arg in node.children:
98-
yield from PreOrderDFS(arg)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from collections.abc import Iterator
2+
3+
from finchlite.symbolic.term import Term, TermTree
4+
5+
6+
def PostOrderDFS(node: Term) -> Iterator[Term]:
7+
if isinstance(node, TermTree):
8+
for arg in node.children:
9+
yield from PostOrderDFS(arg)
10+
yield node
11+
12+
13+
def PreOrderDFS(node: Term) -> Iterator[Term]:
14+
yield node
15+
if isinstance(node, TermTree):
16+
for arg in node.children:
17+
yield from PreOrderDFS(arg)
18+
19+
20+
def intree(n1, n2):
21+
"""
22+
Return True iff `n1` occurs in the subtree rooted at `n2`.
23+
"""
24+
return any(node == n1 for node in PostOrderDFS(n2))
25+
26+
27+
def isdescendant(n1, n2):
28+
"""
29+
True iff `n1` is a strict descendant of `n2`.
30+
"""
31+
if n1 == n2:
32+
return False
33+
return intree(n1, n2)

tests/test_traversal.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from collections import Counter
2+
3+
from finchlite.finch_logic import Field, Literal, MapJoin, Plan, Produces, Table
4+
from finchlite.symbolic import PostOrderDFS, PreOrderDFS, intree, isdescendant
5+
6+
7+
def test_preorder_dfs():
8+
ta = Table(
9+
Literal("A"),
10+
(Field("i"), Field("j")),
11+
)
12+
13+
tb = Table(
14+
Literal("B"),
15+
(Field("j"), Field("k")),
16+
)
17+
18+
prog = Plan(
19+
(
20+
Produces(
21+
(
22+
MapJoin(
23+
Field("op"),
24+
(ta, tb),
25+
),
26+
),
27+
),
28+
)
29+
)
30+
31+
preorder = list(PreOrderDFS(prog))
32+
33+
assert Counter(type(x).__name__ for x in preorder) == Counter(
34+
{"Plan": 1, "Produces": 1, "MapJoin": 1, "Table": 2, "Literal": 2, "Field": 5}
35+
)
36+
37+
pos = {}
38+
for i, obj in enumerate(preorder):
39+
k = id(obj)
40+
if k in pos:
41+
continue
42+
pos[k] = i
43+
for node in preorder:
44+
for child in getattr(node, "children", ()):
45+
assert pos[id(node)] < pos[id(child)]
46+
47+
48+
def test_postorder_dfs():
49+
ta = Table(
50+
Literal("A"),
51+
(Field("i"), Field("j")),
52+
)
53+
54+
tb = Table(
55+
Literal("B"),
56+
(Field("j"), Field("k")),
57+
)
58+
59+
prog = Plan(
60+
(
61+
Produces(
62+
(
63+
MapJoin(
64+
Field("op"),
65+
(ta, tb),
66+
),
67+
),
68+
),
69+
)
70+
)
71+
72+
postorder = list(PostOrderDFS(prog))
73+
74+
assert Counter(type(x).__name__ for x in postorder) == Counter(
75+
{"Plan": 1, "Produces": 1, "MapJoin": 1, "Table": 2, "Literal": 2, "Field": 5}
76+
)
77+
78+
pos = {}
79+
for i, obj in enumerate(postorder):
80+
k = id(obj)
81+
if k in pos:
82+
continue
83+
pos[k] = i
84+
for node in postorder:
85+
for child in getattr(node, "children", ()):
86+
assert pos[id(child)] < pos[id(node)]
87+
88+
89+
def test_intree():
90+
i, j, k = Field("i"), Field("j"), Field("k")
91+
ta = Table(Literal("A"), (i, j))
92+
tb = Table(Literal("B"), (j, k))
93+
op = Field("op")
94+
mj = MapJoin(op, (ta, tb))
95+
prog = Plan((Produces((mj,)),))
96+
97+
assert intree(prog, prog)
98+
assert intree(mj, prog)
99+
assert intree(ta, prog)
100+
assert intree(tb, prog)
101+
102+
103+
def test_isdescendant():
104+
i, j, k = Field("i"), Field("j"), Field("k")
105+
ta = Table(Literal("A"), (i, j))
106+
tb = Table(Literal("B"), (j, k))
107+
op = Field("op")
108+
mj = MapJoin(op, (ta, tb))
109+
prog = Plan((Produces((mj,)),))
110+
111+
assert isdescendant(mj, prog)
112+
assert isdescendant(ta, prog)
113+
assert isdescendant(tb, prog)

0 commit comments

Comments
 (0)