33from pathlib import Path
44import json
55from itertools import groupby
6+ from dataclasses import dataclass
67
78
89class FusibleSubgraphRangesGenerator (SamplePass , ResumableSamplePassMixin ):
@@ -29,7 +30,9 @@ def sample_handled(self, rel_model_path: str) -> bool:
2930
3031 def resume (self , rel_model_path : str ):
3132 analyzer = self ._make_analyzer (rel_model_path )
32- output_obj = analyzer .analyze ()
33+ output_obj = {
34+ "subgraph_ranges" : analyzer .analyze (),
35+ }
3336 self ._save_output (rel_model_path , output_obj )
3437
3538 def _save_output (self , rel_model_path , output_obj ):
@@ -82,27 +85,125 @@ def __init__(
8285 self .start_offset_in_original_graph = start_offset_in_original_graph
8386
8487 def analyze (self ):
85- num_kernels_and_num_ops_list : list [
86- (int , list [int ])
87- ] = self ._make_num_kernels_and_num_ops_list ()
88- num_kernels_and_num_ops_list = sorted (
89- num_kernels_and_num_ops_list , key = lambda pair : pair [0 ]
90- )
91- num_ops_lists = [
92- sorted (num_ops_list )
88+ analysis_ctx = self ._make_analysis_ctx ()
89+ num_kernels_and_num_ops_list = analysis_ctx .num_kernels_and_num_ops_list
90+ # The tail num_kernels equals the head num_kernels for each num_ops_list
91+ naive_proposal_fused_num_ops_lists = [
92+ sorted (set (num_ops_list ))
9393 for _ , num_ops_list in num_kernels_and_num_ops_list
9494 if len (set (num_ops_list )) > 1
9595 ]
96+ proposal_fused_num_ops_lists = self ._merge_all_decreasing_num_ops_lists (
97+ analysis_ctx , naive_proposal_fused_num_ops_lists
98+ )
99+ return self ._create_subgraph_ranges_from_proposal (
100+ analysis_ctx ,
101+ proposal_fused_num_ops_lists ,
102+ )
103+
104+ def _merge_all_decreasing_num_ops_lists (self , analysis_ctx , num_ops_lists ):
105+ dead_loop_detect_cnt = 0
106+ kLimit = 99999
107+ while True :
108+ last_len_num_ops_lists = len (num_ops_lists )
109+ num_ops_lists = self ._merge_one_decreasing_num_ops_lists (
110+ analysis_ctx , num_ops_lists
111+ )
112+ assert last_len_num_ops_lists >= len (num_ops_lists )
113+ if last_len_num_ops_lists == len (num_ops_lists ):
114+ break
115+ dead_loop_detect_cnt += 1
116+ assert dead_loop_detect_cnt < kLimit , f"{ dead_loop_detect_cnt = } "
117+ return num_ops_lists
118+
119+ def _merge_one_decreasing_num_ops_lists (self , analysis_ctx , num_ops_lists ):
120+ merge_pos = self ._detect_mergable_decreasing_position (
121+ analysis_ctx , num_ops_lists
122+ )
123+ if merge_pos is None :
124+ return num_ops_lists
125+ assert merge_pos >= 0
126+ assert merge_pos < len (num_ops_lists ) - 1
127+ return [
128+ * num_ops_lists [:merge_pos ],
129+ [* num_ops_lists [merge_pos ], * num_ops_lists [merge_pos + 1 ]],
130+ * num_ops_lists [merge_pos + 2 :],
131+ ]
132+
133+ def _detect_mergable_decreasing_position (self , analysis_ctx , num_ops_lists ):
134+ def get_cur_tail_num_kernels (i ):
135+ return analysis_ctx .num_kernels4num_ops (num_ops_lists [i ][- 1 ])
136+
137+ def get_next_head_num_kernels (i ):
138+ return analysis_ctx .num_kernels4num_ops (num_ops_lists [i + 1 ][0 ])
139+
140+ for i in range (len (num_ops_lists ) - 1 ):
141+ assert len (num_ops_lists [i ]) > 1
142+ if get_cur_tail_num_kernels (i ) >= get_next_head_num_kernels (i ):
143+ return i
144+ return None
145+
146+ def _create_subgraph_ranges_from_proposal (
147+ self , analysis_ctx , proposal_fused_num_ops_lists
148+ ):
149+ # filter valid num_ops_list
150+
151+ def is_a_range (int_list ):
152+ assert len (int_list ) > 1
153+ return (int_list [- 1 ] + 1 ) - int_list [0 ] == len (int_list )
154+
155+ def have_any_increasing (num_ops_list : list [int ]):
156+ for i , cur_num_ops in enumerate (num_ops_list ):
157+ if i == 0 :
158+ continue
159+ cur_num_kernels = analysis_ctx .num_kernels4num_ops (cur_num_ops )
160+ last_num_kernels = analysis_ctx .num_kernels4num_ops (num_ops_list [i - 1 ])
161+ if cur_num_kernels > last_num_kernels :
162+ return True
163+ return False
164+
165+ def head_eq_tail (num_ops_list : list [int ]):
166+ return analysis_ctx .num_kernels4num_ops (
167+ num_ops_list [0 ]
168+ ) == analysis_ctx .num_kernels4num_ops (num_ops_list [- 1 ])
169+
170+ def head_gt_tail (num_ops_list : list [int ]):
171+ return analysis_ctx .num_kernels4num_ops (
172+ num_ops_list [0 ]
173+ ) > analysis_ctx .num_kernels4num_ops (num_ops_list [- 1 ])
174+
175+ def valid_fused_ops (num_ops_list : list [int ]):
176+ if head_gt_tail (num_ops_list ):
177+ return True
178+ if head_eq_tail (num_ops_list ):
179+ return not have_any_increasing (num_ops_list )
180+ return False
181+
182+ proposal_fused_num_ops_lists = [
183+ sorted (set (num_ops_list )) for num_ops_list in proposal_fused_num_ops_lists
184+ ]
185+ num_ops_lists = [
186+ num_ops_list
187+ for num_ops_list in proposal_fused_num_ops_lists
188+ if len (num_ops_list ) > 1
189+ if is_a_range (num_ops_list )
190+ if valid_fused_ops (num_ops_list )
191+ ]
96192 fusible_subgraph_ranges = [
97193 (start , end )
98194 for num_ops_list in num_ops_lists
99195 for start in [num_ops_list [0 ] - 1 ]
100196 for end in [num_ops_list [- 1 ]]
101197 ]
198+
102199 # sorted by `start`
103- fusible_subgraph_ranges = sorted (
104- fusible_subgraph_ranges , key = lambda pair : pair [0 ]
105- )
200+ def range_sort_key (pair ):
201+ start , end = pair
202+ # smaller `start` first
203+ # bigger `end` first
204+ return (start , - end )
205+
206+ fusible_subgraph_ranges = sorted (fusible_subgraph_ranges , key = range_sort_key )
106207 # remove shadowed
107208 fusible_subgraph_ranges = [
108209 fusible_subgraph_ranges [i ]
@@ -112,6 +213,15 @@ def analyze(self):
112213 ]
113214 return fusible_subgraph_ranges
114215
216+ def _make_analysis_ctx (self ):
217+ return AnalysisContext (
218+ num_kernels_and_num_ops_list = self ._make_num_kernels_and_num_ops_list (),
219+ num_ops2num_kernels = self ._make_num_ops2num_kernels (),
220+ )
221+
222+ def _make_num_ops2num_kernels (self ):
223+ return dict (zip (self .num_subgraph_ops_list , self .num_subgraph_kernels_list ))
224+
115225 def _make_num_kernels_and_num_ops_list (self ):
116226 num_kernels_and_num_ops = zip (
117227 self .num_subgraph_kernels_list ,
@@ -121,7 +231,10 @@ def _make_num_kernels_and_num_ops_list(self):
121231 def get_num_kernels (pair ):
122232 return pair [0 ]
123233
124- num_kernels_and_num_ops = sorted (num_kernels_and_num_ops , key = get_num_kernels )
234+ def get_num_ops (pair ):
235+ return pair [1 ]
236+
237+ num_kernels_and_num_ops = sorted (num_kernels_and_num_ops , key = get_num_ops )
125238 grouped_num_kernels_and_num_ops = groupby (
126239 num_kernels_and_num_ops , key = get_num_kernels
127240 )
@@ -130,3 +243,12 @@ def get_num_kernels(pair):
130243 for num_kernels , group in grouped_num_kernels_and_num_ops
131244 ]
132245 return num_kernels_and_num_ops_list
246+
247+
248+ @dataclass
249+ class AnalysisContext :
250+ num_kernels_and_num_ops_list : list [(int , list [int ])]
251+ num_ops2num_kernels : dict [int , int ]
252+
253+ def num_kernels4num_ops (self , num_ops : int ):
254+ return self .num_ops2num_kernels [num_ops ]
0 commit comments