 import torch
 import torch.nn as nn
 from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
+from graph_net.torch.rp_expr.rp_expr_util import (
+    MakeNestedIndexRangeFromLetsListTokenRpExpr,
+)
 from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
 from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module_without_varify
 
@@ -92,11 +95,18 @@ def _extract_ops(self, model_path: str) -> List[str]:
 
 class SplitAnalyzer:
     def __init__(
-        self, window_size: int = 10, fold_policy: str = "default", fold_times: int = 0
+        self,
+        window_size: int = 10,
+        fold_policy: str = "default",
+        fold_times: int = 0,
+        min_seq_ops: int = 2,
+        max_seq_ops: int = 64,
     ):
         self.window_size = window_size
         self.fold_policy = fold_policy
         self.fold_times = fold_times
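+        # Size bounds applied when filtering subtree ranges into split
+        # candidates in analyze() below.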
+        self.min_seq_ops = min_seq_ops
+        self.max_seq_ops = max_seq_ops
 
     def _resolve_token_to_ops(
         self, tid, num_primitives, token_id2primitive_id, symbol_map
@@ -174,8 +184,18 @@ def analyze(
             fold_times=self.fold_times,
         )
         rp_expr, token_id2primitive_id = rp_parser(inputs_seqs)
-        rp_expr.try_unwrap_body_of_sole_symbol_token()
-        rp_expr.try_recursive_inline_symbol_sole_used(token_id2primitive_id)
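+        # Instead of unwrapping/inlining sole-use symbols in place, build one
+        # nested index-range tree per top-level body expression; the subtree
+        # ranges are filtered by size further below to pick split points.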
+        trees = MakeNestedIndexRangeFromLetsListTokenRpExpr(rp_expr)
+
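+        # Debug helper: renders each Let-bound symbol binding followed by the
+        # body expression; only used by the commented-out dump below.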
+        def get_debug_sprintf():
+            var_and_vals = zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors)
+            ret_lst = [
+                *(f"{var}: {val}" for var, val in var_and_vals),
+                "",
+                str(rp_expr.body_rp_expr),
+            ]
+            return "\n".join(ret_lst)
+
+        # Path("/tmp/rp_expr.txt").write_text(get_debug_sprintf())
 
         num_primitives = len(token_id2primitive_id)
         symbol_map = dict(zip(rp_expr.symbol_token_ids, rp_expr.symbol_token_tensors))
@@ -187,6 +207,8 @@ def analyze(
             if i >= len(rp_expr.body_rp_expr):
                 break
 
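+            # Index-range tree for this body expression (trees is parallel to
+            # rp_expr.body_rp_expr).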
+            tree = trees[i]
+
             target_body_tensor = rp_expr.body_rp_expr[i]
             seq_tokens = target_body_tensor.tolist()
 
@@ -198,24 +220,18 @@ def analyze(
                 )
             )
 
-            current_idx = 0
-            split_positions = set()
             total_len = sum(token2len.get(t, 1) for t in seq_tokens)
 
-            for token_id in seq_tokens:
-                length = token2len.get(token_id, 1)
-                is_pattern = token_id >= num_primitives
-
-                if is_pattern:
-                    if current_idx > 0:
-                        split_positions.add(current_idx)
-                    end_idx = current_idx + length
-                    if end_idx < total_len:
-                        split_positions.add(end_idx)
-
-                current_idx += length
-
-            sorted_splits = sorted(list(split_positions))
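+            # Split at the boundaries of subtree ranges whose size falls within
+            # [min_seq_ops, max_seq_ops]; ranges of length <= 1 contribute no
+            # split positions.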
+            sorted_splits = sorted(
+                set(
+                    split_pos
+                    for start, end in tree.FilterSubTreeRangeBySize(
+                        self.min_seq_ops, self.max_seq_ops
+                    )
+                    for split_pos in (start, end)
+                    if end - start > 1
+                )
+            )
 
             self._print_analysis(
                 model_name, str(original_path), sorted_splits, total_len, full_model_ops
@@ -273,6 +289,8 @@ def main(args):
         window_size=args.window_size,
         fold_policy=args.fold_policy,
         fold_times=args.fold_times,
+        min_seq_ops=args.min_seq_ops,
+        max_seq_ops=args.max_seq_ops,
     )
     results = analyzer.analyze(args.op_names_path_prefix, args.model_list, args.device)
     if args.output_json:
@@ -329,5 +347,17 @@ def main(args):
         default=False,
         help="Resume process",
     )
+    parser.add_argument(
+        "--min-seq-ops",
+        type=int,
+        default=2,
+        help="Minimum number of operators in a candidate split subsequence",
+    )
+    parser.add_argument(
+        "--max-seq-ops",
+        type=int,
+        default=64,
+        help="Maximum number of operators in a candidate split subsequence",
+    )
     args = parser.parse_args()
     main(args)