Commit 0bf4827
[Feature Enhancement] add rp expr parser tools for subgraph capture (#213)
* add rp expr parser tools for subgraph capture
* rp_expr: 3 error tests / 5 passed tests
* rp_expr: keep passed unittest
1 parent b386ec2 commit 0bf4827
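For context: the rp_expr passes added here factor repeated contiguous patterns out of operator-token sequences and bind each pattern to a fresh symbol token, in the spirit of byte-pair encoding. The sketch below is illustrative only, not the FoldTokensPass implementation from this commit; it folds the single most frequent repeated n-gram into one new symbol, and reproduces the pattern/replacement/output values asserted by TestFoldTokensPass in the test file below.

from collections import Counter

def fold_most_frequent(tokens, next_symbol):
    # Count every contiguous n-gram of length >= 2.
    counts = Counter(
        tuple(tokens[i:i + n])
        for n in range(2, len(tokens))
        for i in range(len(tokens) - n + 1)
    )
    # Keep only patterns that actually repeat; prefer higher count, then longer pattern.
    repeated = {p: c for p, c in counts.items() if c > 1}
    if not repeated:
        return None, tokens
    pattern = max(repeated, key=lambda p: (repeated[p], len(p)))
    # Rewrite left to right, replacing non-overlapping occurrences with the new symbol.
    out, i = [], 0
    while i < len(tokens):
        if tuple(tokens[i:i + len(pattern)]) == pattern:
            out.append(next_symbol)
            i += len(pattern)
        else:
            out.append(tokens[i])
            i += 1
    return list(pattern), out

tokens = [3, 4, 5, 1, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7]
pattern, folded = fold_most_frequent(tokens, next_symbol=8)
assert pattern == [3, 4, 5]
assert folded == [8, 1, 8, 6, 2, 8, 6, 7]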

File tree

8 files changed: +1416 -0 lines changed
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
import unittest
from graph_net.torch.rp_expr import Tokenize
from graph_net.torch.rp_expr.rp_expr_passes import (
    FlattenTokenListPass,
    FoldTokensPass,
    RecursiveFoldTokensPass,
    FoldIfTokenIdGreatEqualPass,
)
from graph_net.torch.rp_expr.nested_range import Range, Tree
from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
from graph_net.torch.rp_expr.rp_expr_util import (
    MakeNestedIndexRangeFromLetsListTokenRpExpr,
)


class TestTokenize(unittest.TestCase):
    """Tests tokenization of primitive ID lists into symbolic token sequences."""

    def test_simple(self):
        primitive_id_lists = [list(range(10 + i)) for i in range(5)]
        token_list, id_allocator, _ = Tokenize(primitive_id_lists)
        self.assertEqual(len(token_list.tensors), len(primitive_id_lists))


class TestFlattenTokenListPass(unittest.TestCase):
    """Tests flattening of nested token structures into linear sequences."""

    def test_simple(self):
        base = 10
        size = 5
        primitive_id_lists = [list(range(base + i)) for i in range(size)]
        token_list, id_allocator, _ = Tokenize(primitive_id_lists)
        rp_expr_pass = FlattenTokenListPass(id_allocator)
        success, flattened_rp_expr_pass = rp_expr_pass(token_list)
        self.assertTrue(success)
        self.assertEqual(id_allocator.NextTokenId(), base + 2 * size - 1)


class TestFoldTokensPass(unittest.TestCase):
    """Tests folding of the most frequent contiguous token pattern into a single symbol."""

    def test_simple(self):
        base = 3
        size = 3
        primitive_id_lists = [list(range(base + i)) for i in range(size)]
        token_list, id_allocator, _ = Tokenize(primitive_id_lists)
        flatten_pass = FlattenTokenListPass(id_allocator)
        _, flattened_rp_expr = flatten_pass(token_list)
        fold_pass = FoldTokensPass(id_allocator)
        success, fold_rp_expr = fold_pass(flattened_rp_expr.flattened_tensor)
        self.assertTrue(success)
        input = flattened_rp_expr.flattened_tensor.tensor.numpy().tolist()
        pattern = fold_rp_expr.symbol_token_tensors[0].numpy().tolist()
        replacement = fold_rp_expr.symbol_token_ids[0]
        output = fold_rp_expr.body_rp_expr.tensor.numpy().tolist()
        self.assertEqual(input, [3, 4, 5, 1, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7])
        self.assertEqual(pattern, [3, 4, 5])
        self.assertEqual(replacement, 8)
        self.assertEqual(output, [8, 1, 8, 6, 2, 8, 6, 7])


class TestRecursiveFoldTokensPass(unittest.TestCase):
    """Tests recursive folding of repeated patterns at multiple levels of nesting."""

    def test_simple(self):
        base = 3
        size = 3
        primitive_id_lists = [list(range(base + i)) for i in range(size)]
        token_list, id_allocator, _ = Tokenize(primitive_id_lists)
        flatten_pass = FlattenTokenListPass(id_allocator)
        _, flattened_rp_expr = flatten_pass(token_list)
        fold_pass = RecursiveFoldTokensPass(id_allocator)
        success, fold_rp_expr = fold_pass(flattened_rp_expr.flattened_tensor)
        self.assertTrue(success)
        input = flattened_rp_expr.flattened_tensor.tensor.numpy().tolist()
        pattern = [x.numpy().tolist() for x in fold_rp_expr.symbol_token_tensors]
        replacement = fold_rp_expr.symbol_token_ids
        output = fold_rp_expr.body_rp_expr.tensor.numpy().tolist()
        self.assertEqual(input, [3, 4, 5, 1, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7])
        self.assertEqual(pattern, [[3, 4, 5], [8, 6]])
        self.assertEqual(replacement, [8, 9])
        self.assertEqual(output, [8, 1, 9, 2, 9, 7])


class TestFoldIfTokenIdGreatEqualPass(unittest.TestCase):
    """Tests conditional folding only for tokens with ID greater than or equal to a threshold."""

    def test_simple(self):
        base = 3
        size = 3
        primitive_id_lists = [list(range(base + i)) for i in range(size)]
        token_list, id_allocator, _ = Tokenize(primitive_id_lists)
        flatten_pass = FlattenTokenListPass(id_allocator)
        _, flattened_rp_expr = flatten_pass(token_list)
        fold_pass = RecursiveFoldTokensPass(id_allocator)
        success, fold_rp_expr = fold_pass(flattened_rp_expr.flattened_tensor)
        self.assertTrue(success)
        threshold_fold_pass = FoldIfTokenIdGreatEqualPass(
            id_allocator=id_allocator,
            threshold_start_token_id=len(primitive_id_lists),
        )
        success, threshold_fold_rp_expr = threshold_fold_pass(fold_rp_expr.body_rp_expr)
        self.assertTrue(success)
        input = fold_rp_expr.body_rp_expr.tensor.numpy().tolist()
        pattern = [
            x.numpy().tolist() for x in threshold_fold_rp_expr.symbol_token_tensors
        ]
        replacement = threshold_fold_rp_expr.symbol_token_ids
        self.assertEqual(len(threshold_fold_rp_expr.body_rp_expr), 3)
        output = [x.numpy().tolist() for x in threshold_fold_rp_expr.body_rp_expr]
        self.assertEqual(input, [8, 1, 9, 2, 9, 7])
        self.assertEqual(pattern, [[9, 7]])
        self.assertEqual(replacement, [10])
        self.assertEqual(output, [[8], [9], [10]])


if __name__ == "__main__":
    unittest.main()
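Reading the RecursiveFoldTokensPass expectations above as let-bindings: symbol 8 binds [3, 4, 5] and symbol 9 binds [8, 6], so the folded body [8, 1, 9, 2, 9, 7] is the original flattened sequence with both patterns factored out. A minimal decoder (a hypothetical helper, not part of this commit) confirms the round trip:

def expand(tokens, bindings):
    # Recursively substitute each folded symbol back into its primitive ids.
    out = []
    for tok in tokens:
        if tok in bindings:
            out.extend(expand(bindings[tok], bindings))
        else:
            out.append(tok)
    return out

bindings = {8: [3, 4, 5], 9: [8, 6]}
assert expand([8, 1, 9, 2, 9, 7], bindings) == [3, 4, 5, 1, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7]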
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
from .rp_expr import Tokenize
from .rp_expr_parser import RpExprParser
from .nested_range import Range, Tree
from .rp_expr_util import MakeNestedIndexRangeFromLetsListTokenRpExpr
from .rp_expr_passes import (
    FlattenTokenListPass,
    FoldTokensPass,
    RecursiveFoldTokensPass,
    FoldIfTokenIdGreatEqualPass,
)

__all__ = [
    "Tokenize",
    "RpExprParser",
    "Range",
    "Tree",
    "MakeNestedIndexRangeFromLetsListTokenRpExpr",
    "FlattenTokenListPass",
    "FoldTokensPass",
    "RecursiveFoldTokensPass",
    "FoldIfTokenIdGreatEqualPass",
]
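With these exports in place, callers can parse a batch of primitive-id sequences through the package root. A minimal usage sketch with made-up input data; the constructor arguments mirror how RpExprParser is invoked in the LongestRpExprParser file below:

from graph_net.torch.rp_expr import RpExprParser

# Two op-id sequences that share the prefix [0, 1, 2]; repeated patterns are
# expected to come back as let-bindings over fresh symbol tokens.
parser = RpExprParser(1024, fold_policy="default", fold_times=1)
lets_list_rp_expr, token_id2primitive_id = parser([[0, 1, 2, 3], [0, 1, 2, 4]])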
Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
import typing as t
from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
from graph_net.torch.rp_expr.rp_expr import PrimitiveId, LetsListTokenRpExpr
import numpy as np
import sys


class LongestRpExprParser:
    def __init__(self, max_window_size=1024, min_window_size=4):
        self.max_window_size = max_window_size
        self.min_window_size = min_window_size

    def __call__(self, primitive_id_lists: t.List[t.List[PrimitiveId]]):
        fold_policy = "default"
        rp_expr_parser = RpExprParser(
            self.max_window_size,
            fold_policy=fold_policy,
            fold_times=1,
        )
        lets_list_rp_expr, token_id2primitive_id = rp_expr_parser(primitive_id_lists)
        for window_size in self._get_sub_window_sizes():
            rp_expr_parser = RpExprParser(
                window_size,
                fold_policy=fold_policy,
                fold_times=1,
            )
            cur_primitive_id_lists = [
                [token_id2primitive_id[token_id] for token_id in tensor.tolist()]
                for tensor in lets_list_rp_expr.get_pure_primitive_binding_tensors(
                    token_id2primitive_id
                )
            ]
            cur_lets_list_rp_expr, cur_token_id2primitive_id = rp_expr_parser(
                cur_primitive_id_lists
            )
            # cur_lets_list_rp_expr.try_unwrap_body_of_sole_symbol_token()
            lets_list_rp_expr = self._merge_lets_list_rp_expr(
                inner=cur_lets_list_rp_expr,
                outer=lets_list_rp_expr,
                inner_token_id2primitive_id=cur_token_id2primitive_id,
                outer_token_id2primitive_id=token_id2primitive_id,
            )
        lets_list_rp_expr.try_recursive_inline_symbol_sole_used(
            token_id2primitive_id=token_id2primitive_id
        )
        # lets_list_rp_expr.try_unwrap_body_of_sole_symbol_token()
        return lets_list_rp_expr, token_id2primitive_id

    def _merge_lets_list_rp_expr(
        self,
        inner,
        outer,
        inner_token_id2primitive_id,
        outer_token_id2primitive_id,
    ):
        def get_inner_token_id2outer_token_id():
            primitive_id2outer_token_id = {}
            for token_id, primitive_id in enumerate(outer_token_id2primitive_id):
                assert primitive_id not in primitive_id2outer_token_id
                primitive_id2outer_token_id[primitive_id] = token_id
            return [
                primitive_id2outer_token_id[primitive_id]
                for primitive_id in inner_token_id2primitive_id
            ]

        kInner = "inner"
        kOuter = "outer"
        uid2new_symbol_token = self._make_uid2new_symbol_token_id(
            inner=inner,
            outer=outer,
            inner_uid_prefix=kInner,
            outer_uid_prefix=kOuter,
            outer_primitive_table_size=len(outer_token_id2primitive_id),
        )
        inner_symbol_token_ids = self._convert_symbol_token_ids(
            symbol_token_ids=inner.symbol_token_ids,
            new_token4old_token=(
                lambda old_token: uid2new_symbol_token[f"{kInner}{old_token}"]
            ),
        )
        inner_token_id2outer_token_id = get_inner_token_id2outer_token_id()
        inner_symbol_token_tensors = self._convert_token_tensors(
            inner.symbol_token_tensors,
            new_token4old_primitive_token=(
                lambda old_token: inner_token_id2outer_token_id[old_token]
            ),
            new_token4old_symbol_token=(
                lambda old_token: uid2new_symbol_token[f"{kInner}{old_token}"]
            ),
            primitive_ids_table_size=len(inner_token_id2primitive_id),
        )

        inner_body_rp_expr = self._convert_token_tensors(
            inner.body_rp_expr,
            new_token4old_primitive_token=(
                lambda old_token: inner_token_id2outer_token_id[old_token]
            ),
            new_token4old_symbol_token=(
                lambda old_token: uid2new_symbol_token[f"{kInner}{old_token}"]
            ),
            primitive_ids_table_size=len(inner_token_id2primitive_id),
        )

        inner_symbol_token2token_tensor = {
            symbol_token: token_tensor
            for symbol_token, token_tensor in zip(
                inner_symbol_token_ids, inner_symbol_token_tensors
            )
        }

        outer_symbol_token_tensors = self._convert_outer_symbol_binding_token_tensors(
            inner_body_rp_expr=inner_body_rp_expr,
            inner_symbol_token2token_tensor=inner_symbol_token2token_tensor,
            outer_lets_list_rp_expr=outer,
            new_token4old_primitive_token=lambda x: x,
            new_token4old_symbol_token=(
                lambda old_token: uid2new_symbol_token[f"{kOuter}{old_token}"]
            ),
            outer_token_id2primitive_id=outer_token_id2primitive_id,
        )

        symbol_token_ids = inner_symbol_token_ids + self._convert_symbol_token_ids(
            symbol_token_ids=outer.symbol_token_ids,
            new_token4old_token=(
                lambda old_token: uid2new_symbol_token[f"{kOuter}{old_token}"]
            ),
        )

        symbol_token_tensors = inner_symbol_token_tensors + outer_symbol_token_tensors

        body_rp_expr = self._convert_token_tensors(
            outer.body_rp_expr,
            new_token4old_primitive_token=lambda x: x,
            new_token4old_symbol_token=(
                lambda old_token: uid2new_symbol_token[f"{kOuter}{old_token}"]
            ),
            primitive_ids_table_size=len(outer_token_id2primitive_id),
        )
        ret_lets_list_token_rp_expr = LetsListTokenRpExpr(
            symbol_token_ids=symbol_token_ids,
            symbol_token_tensors=symbol_token_tensors,
            body_rp_expr=body_rp_expr,
        )
        ret_lets_list_token_rp_expr.move_pure_primitive_bindings_front(
            outer_token_id2primitive_id
        )
        return ret_lets_list_token_rp_expr

    def _convert_outer_symbol_binding_token_tensors(
        self,
        inner_body_rp_expr,
        inner_symbol_token2token_tensor,
        outer_lets_list_rp_expr,
        new_token4old_primitive_token,
        new_token4old_symbol_token,
        outer_token_id2primitive_id,
    ):
        indexes = outer_lets_list_rp_expr.get_pure_primitive_binding_indexes(
            outer_token_id2primitive_id
        )
        assert len(inner_body_rp_expr) == len(indexes)
        index2inner_body_rp_expr_idx = {
            index: inner_body_rp_expr_idx
            for inner_body_rp_expr_idx, index in enumerate(indexes)
        }
        old_tensors = outer_lets_list_rp_expr.symbol_token_tensors
        return [
            (
                inner_body_rp_expr[index2inner_body_rp_expr_idx[index]]
                if index in index2inner_body_rp_expr_idx
                else self._convert_token_tensor(
                    tensor=old_tensors[index],
                    new_token4old_primitive_token=new_token4old_primitive_token,
                    new_token4old_symbol_token=new_token4old_symbol_token,
                    primitive_ids_table_size=len(outer_token_id2primitive_id),
                )
            )
            for index in range(len(old_tensors))
        ]

    def _convert_token_tensors(
        self,
        tensors,
        new_token4old_primitive_token,
        new_token4old_symbol_token,
        primitive_ids_table_size,
    ):
        return [
            self._convert_token_tensor(
                tensor,
                new_token4old_primitive_token,
                new_token4old_symbol_token,
                primitive_ids_table_size,
            )
            for tensor in tensors
        ]

    def _convert_token_tensor(
        self,
        tensor,
        new_token4old_primitive_token,
        new_token4old_symbol_token,
        primitive_ids_table_size,
    ):
        return np.array(
            [
                (
                    new_token4old_primitive_token(token_id)
                    if token_id < primitive_ids_table_size
                    else new_token4old_symbol_token(token_id)
                )
                for token_id in tensor.tolist()
            ],
            dtype=np.int64,
        )

    def _make_uid2new_symbol_token_id(
        self,
        inner,
        outer,
        inner_uid_prefix,
        outer_uid_prefix,
        outer_primitive_table_size,
    ):
        new_symbol_token_id = outer_primitive_table_size

        def get_new_symbol_token_id():
            nonlocal new_symbol_token_id
            ret = new_symbol_token_id
            new_symbol_token_id += 1
            return ret

        uid2new_symbol_token_id = {}
        for inner_symbol_token_id in inner.symbol_token_ids:
            uid = f"{inner_uid_prefix}{inner_symbol_token_id}"
            uid2new_symbol_token_id[uid] = get_new_symbol_token_id()
        for outer_symbol_token_id in outer.symbol_token_ids:
            uid = f"{outer_uid_prefix}{outer_symbol_token_id}"
            uid2new_symbol_token_id[uid] = get_new_symbol_token_id()
        return uid2new_symbol_token_id

    def _convert_symbol_token_ids(self, symbol_token_ids, new_token4old_token):
        return [
            new_token4old_token(symbol_token_id) for symbol_token_id in symbol_token_ids
        ]

    def _get_sub_window_sizes(self):
        min_window_size = max(1, self.min_window_size)
        window_size = self.max_window_size // 2
        while window_size > min_window_size:
            yield window_size
            window_size = window_size // 2
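The loop in __call__ re-parses the pure-primitive bindings with progressively smaller windows, merging each finer result back into the outer expression; _get_sub_window_sizes supplies that coarse-to-fine schedule by halving the window until it would reach min_window_size. For example:

parser = LongestRpExprParser(max_window_size=64, min_window_size=4)
assert list(parser._get_sub_window_sizes()) == [32, 16, 8]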
