@@ -279,27 +279,18 @@ def run_evaluation(
279279 ), f"[ERROR] test failed for { samples_dir } , please check the log."
280280
281281
282- def reconstruct_subgraph_size (split_positions : List [int ]) -> List [list ]:
283- """Reconstructs subgraph size based on sorted split positions."""
284- deduplicated_splits = sorted (set (split_positions ))
285-
286- subgraph_size = [
287- deduplicated_splits [i : i + 2 ] for i in range (len (deduplicated_splits ) - 1 )
288- ]
289- return subgraph_size
290-
291-
292- def calculate_split_positions_for_subgraph (subgraph_range , max_subgraph_size ):
293- assert isinstance (subgraph_range , (list , tuple )) and len (subgraph_range ) == 2
294-
295- # subgraph_size: the start and end position in original model.
296- start_pos , end_pos = subgraph_range
297- end_pos = kMaxGraphSize if end_pos == float ("inf" ) else end_pos
282+ def reconstruct_split_positions_for_subgraph (
283+ split_positions , subgraph_idx , max_subgraph_size
284+ ):
285+ assert (
286+ subgraph_idx < len (split_positions ) - 1
287+ ), f"subgraph_idx { subgraph_idx } is out of bounds of split_positions: { split_positions } ."
298288
299- split_positions = set (
289+ start_pos , end_pos = split_positions [subgraph_idx : subgraph_idx + 2 ]
290+ new_split_positions = set (
300291 range (start_pos , end_pos + max_subgraph_size - 1 , max_subgraph_size )
301292 )
302- return sorted (list (set ( split_positions ) ))
293+ return sorted (list (new_split_positions ))
303294
304295
305296def generate_initial_tasks (args ):
@@ -310,17 +301,14 @@ def generate_initial_tasks(args):
310301 tasks_map = {}
311302 max_subgraph_size = args .max_subgraph_size
312303
304+ initial_split_positions = reconstruct_split_positions_for_subgraph (
305+ [0 , kMaxGraphSize ], 0 , max_subgraph_size
306+ )
313307 for model_path in initial_failures :
314308 model_name = get_model_name_with_subgraph_tag (model_path )
315-
316- initial_range = [0 , kMaxGraphSize ]
317- initial_splits = calculate_split_positions_for_subgraph (
318- initial_range , max_subgraph_size
319- )
320-
321309 tasks_map [model_name ] = {
322310 "original_path" : model_path ,
323- "split_positions" : initial_splits ,
311+ "split_positions" : initial_split_positions ,
324312 }
325313
326314 running_states = {
@@ -355,19 +343,14 @@ def generate_successor_tasks(base_output_dir, current_pass_id):
355343
356344 for subgraph_path in sorted (prev_config .incorrect_models ):
357345 model_name , subgraph_idx = extract_model_name_and_subgraph_idx (subgraph_path )
346+ print (f"{ subgraph_path = } " )
358347
359348 assert model_name in prev_tasks_map
360349 pre_task_for_model = prev_tasks_map [model_name ]
361350
362351 prev_split_positions = pre_task_for_model .get ("split_positions" , [])
363- subgraph_ranges = reconstruct_subgraph_size (prev_split_positions )
364-
365- assert subgraph_idx < len (
366- subgraph_ranges
367- ), f"subgraph_idx { subgraph_idx } is out of bounds for { model_name } (previous split_positions: { prev_split_positions } )"
368-
369- split_positions = calculate_split_positions_for_subgraph (
370- subgraph_ranges [subgraph_idx ], max_subgraph_size
352+ split_positions = reconstruct_split_positions_for_subgraph (
353+ prev_split_positions , subgraph_idx , max_subgraph_size
371354 )
372355 if model_name not in tasks_map :
373356 tasks_map [model_name ] = {
@@ -381,6 +364,7 @@ def generate_successor_tasks(base_output_dir, current_pass_id):
381364 tasks_map [model_name ]["split_positions" ] = list (
382365 sorted (set (merged_split_positions ))
383366 )
367+ print (f"{ tasks_map = } " )
384368
385369 return tasks_map , max_subgraph_size , prev_config .running_states
386370
@@ -409,6 +393,7 @@ def prepare_tasks_and_verify(args, current_pass_id, base_output_dir):
409393 )
410394 sys .exit (0 )
411395
396+ sys .exit (0 )
412397 return tasks_map , max_subgraph_size , running_states
413398
414399
@@ -450,8 +435,8 @@ def execute_decomposition_phase(max_subgraph_size, tasks_map, framework, workspa
450435 split_positions = task_info ["split_positions" ]
451436 if not split_positions or len (split_positions ) < 2 :
452437 continue
453- new_split_positions = calculate_split_positions_for_subgraph (
454- split_positions [ 0 : 2 ] , max_subgraph_size
438+ new_split_positions = reconstruct_split_positions_for_subgraph (
439+ split_positions , 0 , max_subgraph_size
455440 )
456441 task_info ["split_positions" ] = new_split_positions
457442 else :
0 commit comments