Commit c24ad82

add min_seq_ops and max_seq_ops for typical_sequence_split_points.py (#451)
* init 'symbolic_dimension_reifier' field in graph_net.json
* remove unused files
* add fx_graph_module_unserialize_test.sh
* add min_seq_ops and max_seq_ops for typical_sequence_split_points.py
1 parent 0c63b54 commit c24ad82

File tree

3 files changed: +52 -20 lines changed


graph_net/test/dev_model_list/validation_error_model_list.txt

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ samples/torchgeometric/RECT_L
 samples/transformers-auto-model/bge-small-en-v1.5
 samples/transformers-auto-model/distilbert_distilbert-base-multilingual-cased
 samples/transformers-auto-model/OFA-Sys_chinese-clip-vit-large-patch14
-samples/transformers-auto-model/opus-mt-ase-es
+samples/transformers-auto-model/miangoar_esm2_t12_35M_UR50D-finetuned-secondary-structure-classification
 samples/transformers-auto-model/opus-mt-en-gv
 samples/transformers-auto-model/opus-mt-en-phi
 samples/transformers-auto-model/opus-mt-en-sal

graph_net/test/typical_sequence_decomposer_test.sh

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ python3 -m graph_net.torch.typical_sequence_split_points \
     --window-size 10 \
     --fold-policy default \
     --fold-times 10 \
+    --min-seq-ops 4 \
+    --max-seq-ops 16 \
     --output-json "$DECOMPOSE_PATH/split_results.json"

 python3 -m graph_net.model_path_handler \
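Note: the test pins --min-seq-ops 4 and --max-seq-ops 16, tighter than the defaults of 2 and 64 introduced in typical_sequence_split_points.py below, so the decomposer test exercises the size filter rather than the near-pass-through defaults.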

graph_net/torch/typical_sequence_split_points.py

Lines changed: 49 additions & 19 deletions
@@ -6,6 +6,9 @@
 import torch
 import torch.nn as nn
 from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
+from graph_net.torch.rp_expr.rp_expr_util import (
+    MakeNestedIndexRangeFromLetsListTokenRpExpr,
+)
 from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
 from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module_without_varify

@@ -92,11 +95,18 @@ def _extract_ops(self, model_path: str) -> List[str]:

 class SplitAnalyzer:
     def __init__(
-        self, window_size: int = 10, fold_policy: str = "default", fold_times: int = 0
+        self,
+        window_size: int = 10,
+        fold_policy: str = "default",
+        fold_times: int = 0,
+        min_seq_ops: int = 2,
+        max_seq_ops: int = 64,
     ):
         self.window_size = window_size
         self.fold_policy = fold_policy
         self.fold_times = fold_times
+        self.min_seq_ops = min_seq_ops
+        self.max_seq_ops = max_seq_ops

     def _resolve_token_to_ops(
         self, tid, num_primitives, token_id2primitive_id, symbol_map

@@ -174,8 +184,18 @@ def analyze(
             fold_times=self.fold_times,
         )
         rp_expr, token_id2primitive_id = rp_parser(inputs_seqs)
-        rp_expr.try_unwrap_body_of_sole_symbol_token()
-        rp_expr.try_recursive_inline_symbol_sole_used(token_id2primitive_id)
+        trees = MakeNestedIndexRangeFromLetsListTokenRpExpr(rp_expr)
+
+        def get_debug_sprintf():
+            var_and_vals = zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors)
+            ret_lst = [
+                *(f"{var}: {val}" for var, val in var_and_vals),
+                "",
+                str(rp_expr.body_rp_expr),
+            ]
+            return "\n".join(ret_lst)
+
+        # Path("/tmp/rp_expr.txt").write_text(get_debug_sprintf())

         num_primitives = len(token_id2primitive_id)
         symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors))

@@ -187,6 +207,8 @@ def analyze(
             if i >= len(rp_expr.body_rp_expr):
                 break

+            tree = trees[i]
+
             target_body_tensor = rp_expr.body_rp_expr[i]
             seq_tokens = target_body_tensor.tolist()

@@ -198,24 +220,18 @@ def analyze(
                 )
             )

-            current_idx = 0
-            split_positions = set()
             total_len = sum(token2len.get(t, 1) for t in seq_tokens)

-            for token_id in seq_tokens:
-                length = token2len.get(token_id, 1)
-                is_pattern = token_id >= num_primitives
-
-                if is_pattern:
-                    if current_idx > 0:
-                        split_positions.add(current_idx)
-                    end_idx = current_idx + length
-                    if end_idx < total_len:
-                        split_positions.add(end_idx)
-
-                current_idx += length
-
-            sorted_splits = sorted(list(split_positions))
+            sorted_splits = sorted(
+                set(
+                    split_pos
+                    for start, end in tree.FilterSubTreeRangeBySize(
+                        self.min_seq_ops, self.max_seq_ops
+                    )
+                    for split_pos in (start, end)
+                    if end - start > 1
+                )
+            )

             self._print_analysis(
                 model_name, str(original_path), sorted_splits, total_len, full_model_ops

@@ -273,6 +289,8 @@ def main(args):
         window_size=args.window_size,
         fold_policy=args.fold_policy,
         fold_times=args.fold_times,
+        min_seq_ops=args.min_seq_ops,
+        max_seq_ops=args.max_seq_ops,
     )
     results = analyzer.analyze(args.op_names_path_prefix, args.model_list, args.device)
     if args.output_json:

@@ -329,5 +347,17 @@ def main(args):
         default=False,
         help="Resume process",
     )
+    parser.add_argument(
+        "--min-seq-ops",
+        type=int,
+        default=2,
+        help="minimum number of sequence operators",
+    )
+    parser.add_argument(
+        "--max-seq-ops",
+        type=int,
+        default=64,
+        help="maximum number of sequence operators",
+    )
     args = parser.parse_args()
     main(args)
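For a concrete view of the new split-point derivation, here is a minimal, self-contained sketch. The NestedRange class and its FilterSubTreeRangeBySize method are illustrative stand-ins for the nested index-range trees returned by MakeNestedIndexRangeFromLetsListTokenRpExpr (which this commit does not show); only the final comprehension mirrors the committed code.

# Toy stand-in for the nested index-range trees; illustrative only.
from dataclasses import dataclass, field
from typing import Iterator, List, Tuple


@dataclass
class NestedRange:
    start: int
    end: int
    children: List["NestedRange"] = field(default_factory=list)

    def FilterSubTreeRangeBySize(
        self, min_seq_ops: int, max_seq_ops: int
    ) -> Iterator[Tuple[int, int]]:
        # Yield (start, end) for every subtree whose op count is in bounds.
        if min_seq_ops <= self.end - self.start <= max_seq_ops:
            yield (self.start, self.end)
        for child in self.children:
            yield from child.FilterSubTreeRangeBySize(min_seq_ops, max_seq_ops)


# A 20-op sequence whose repeated-pattern structure nests as below.
tree = NestedRange(
    0, 20, [NestedRange(0, 8), NestedRange(8, 20, [NestedRange(8, 10)])]
)

# Same comprehension as the new code in analyze(): collect the boundaries
# of every subtree whose size lies in [min_seq_ops, max_seq_ops].
sorted_splits = sorted(
    set(
        split_pos
        for start, end in tree.FilterSubTreeRangeBySize(4, 16)
        for split_pos in (start, end)
        if end - start > 1
    )
)
print(sorted_splits)  # [0, 8, 20]: (0, 8) and (8, 20) qualify; (8, 10) and (0, 20) do not.

Compared with the old token-walking loop, which added a boundary around every folded pattern it encountered, split points now come from the nested structure of the repeated-pattern expression, and only subtrees whose op count falls within [min_seq_ops, max_seq_ops] contribute boundaries.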
