@@ -58,49 +58,14 @@ def _regenerate_viewer_if_possible(output_dir: Path) -> bool:
 
     Returns True if viewer was regenerated, False otherwise.
     """
-    from openadapt_ml.training.trainer import _enhance_comparison_to_unified_viewer
-
-    # Look for base comparison file
-    base_file = output_dir / "comparison.html"
-    if not base_file.exists():
-        # Try to find any comparison HTML
-        comparison_files = list(output_dir.glob("*comparison*.html"))
-        if comparison_files:
-            base_file = comparison_files[0]
-        else:
-            return False
-
-    # Load predictions from checkpoint files
-    predictions_by_checkpoint = {"None": []}
-    for pred_file in output_dir.glob("predictions_*.json"):
-        checkpoint_name = pred_file.stem.replace("predictions_", "")
-        if "epoch" in checkpoint_name:
-            display_name = checkpoint_name.replace("epoch", "Epoch ").replace("_", " ").title()
-        elif checkpoint_name == "preview":
-            display_name = "Preview"
-        else:
-            display_name = checkpoint_name.title()
-
-        try:
-            with open(pred_file) as f:
-                predictions_by_checkpoint[display_name] = json.load(f)
-        except json.JSONDecodeError:
-            pass
-
-    # Get capture info
-    capture_id = "capture"
-    goal = "Complete the recorded workflow"
+    from openadapt_ml.training.trainer import generate_unified_viewer_from_output_dir
 
     try:
-        _enhance_comparison_to_unified_viewer(
-            base_file,
-            predictions_by_checkpoint,
-            output_dir / "viewer.html",
-            capture_id,
-            goal,
-        )
-        print(f"Regenerated viewer: {output_dir / 'viewer.html'}")
-        return True
+        viewer_path = generate_unified_viewer_from_output_dir(output_dir)
+        if viewer_path:
+            print(f"Regenerated viewer: {viewer_path}")
+            return True
+        return False
     except Exception as e:
         print(f"Could not regenerate viewer: {e}")
         return False
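The consolidated helper itself is not part of this diff. Below is a rough sketch of what `generate_unified_viewer_from_output_dir` plausibly does, assembled from the inline logic removed above; it is not the actual implementation in `openadapt_ml/training/trainer.py`. The training-log filename (`training_log.json`) is an assumption, and `_enhance_comparison_to_unified_viewer` is the existing private helper that both call sites previously invoked directly, as shown in the removed lines.

```python
# Sketch only: reconstructed from the removed inline logic, not the real trainer.py code.
import json
from pathlib import Path
from typing import Optional

# Existing private helper confirmed by the removed call sites in this diff.
from openadapt_ml.training.trainer import _enhance_comparison_to_unified_viewer


def generate_unified_viewer_from_output_dir(output_dir: Path) -> Optional[Path]:
    """Build viewer.html from comparison/prediction artifacts in output_dir.

    Returns the path to the generated viewer, or None when no comparison
    HTML exists to build on.
    """
    # Prefer the latest epoch-specific comparison, then fall back to comparison.html,
    # then to any file matching *comparison*.html.
    epoch_files = sorted(output_dir.glob("comparison_epoch*.html"))
    base_file = epoch_files[-1] if epoch_files else output_dir / "comparison.html"
    if not base_file.exists():
        candidates = list(output_dir.glob("*comparison*.html"))
        if not candidates:
            return None
        base_file = candidates[0]

    # Collect predictions keyed by a human-readable checkpoint name.
    predictions_by_checkpoint: dict[str, list] = {"None": []}
    for pred_file in output_dir.glob("predictions_*.json"):
        name = pred_file.stem.replace("predictions_", "")
        if "epoch" in name:
            display = name.replace("epoch", "Epoch ").replace("_", " ").title()
        elif name == "preview":
            display = "Preview"
        else:
            display = name.title()
        try:
            data = json.loads(pred_file.read_text())
        except json.JSONDecodeError:
            continue
        # Some prediction files wrap the list in {"predictions": [...]}.
        predictions_by_checkpoint[display] = (
            data["predictions"] if isinstance(data, dict) and "predictions" in data else data
        )

    # Recover capture metadata from the training log when present.
    capture_id, goal = "capture", "Complete the recorded workflow"
    log_file = output_dir / "training_log.json"  # hypothetical filename, not from the diff
    if log_file.exists():
        try:
            capture_path = json.loads(log_file.read_text()).get("capture_path", "")
            if capture_path:
                capture_id = Path(capture_path).name
        except (json.JSONDecodeError, KeyError):
            pass

    viewer_path = output_dir / "viewer.html"
    _enhance_comparison_to_unified_viewer(
        base_file, predictions_by_checkpoint, viewer_path, capture_id, goal
    )
    return viewer_path
```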
@@ -399,7 +364,7 @@ def cmd_viewer(args: argparse.Namespace) -> int:
     """Regenerate viewer from local training output."""
     from openadapt_ml.training.trainer import (
         generate_training_dashboard,
-        _enhance_comparison_to_unified_viewer,
+        generate_unified_viewer_from_output_dir,
         TrainingState,
         TrainingConfig,
     )
@@ -435,66 +400,12 @@ def cmd_viewer(args: argparse.Namespace) -> int:
     (current_dir / "dashboard.html").write_text(dashboard_html)
     print(f" Regenerated: dashboard.html")
 
-    # Find comparison HTML to enhance
-    # Try epoch-specific files first, then fall back to generic comparison.html
-    comparison_files = list(current_dir.glob("comparison_epoch*.html"))
-    base_file = None
-    if comparison_files:
-        # Use the latest epoch comparison
-        base_file = sorted(comparison_files)[-1]
-    elif (current_dir / "comparison.html").exists():
-        base_file = current_dir / "comparison.html"
-
-    if base_file:
-        print(f" Using base file: {base_file.name}")
-
-        # Load all prediction files
-        predictions_by_checkpoint = {"None": []}
-        for pred_file in current_dir.glob("predictions_*.json"):
-            checkpoint_name = pred_file.stem.replace("predictions_", "")
-            # Map to display name
-            if "epoch" in checkpoint_name:
-                display_name = checkpoint_name.replace("epoch", "Epoch ").replace("_", " ").title()
-            elif checkpoint_name == "preview":
-                display_name = "Preview"
-            else:
-                display_name = checkpoint_name.title()
-
-            try:
-                with open(pred_file) as f:
-                    data = json.load(f)
-                # Handle predictions JSON structure
-                if isinstance(data, dict) and "predictions" in data:
-                    predictions_by_checkpoint[display_name] = data["predictions"]
-                else:
-                    predictions_by_checkpoint[display_name] = data
-                print(f" Loaded predictions from {pred_file.name}")
-            except json.JSONDecodeError:
-                print(f" Warning: Could not parse {pred_file.name}")
-
-        # Get capture info from training log
-        capture_id = "capture"
-        goal = "Complete the recorded workflow"
-        if log_file.exists():
-            try:
-                with open(log_file) as f:
-                    log_data = json.load(f)
-                capture_path = log_data.get("capture_path", "")
-                if capture_path:
-                    capture_id = Path(capture_path).name
-            except (json.JSONDecodeError, KeyError):
-                pass
-
-        _enhance_comparison_to_unified_viewer(
-            base_file,
-            predictions_by_checkpoint,
-            current_dir / "viewer.html",
-            capture_id,
-            goal,
-        )
-        print(f"\n Generated: {current_dir / 'viewer.html'}")
+    # Generate unified viewer using consolidated function
+    viewer_path = generate_unified_viewer_from_output_dir(current_dir)
+    if viewer_path:
+        print(f"\n Generated: {viewer_path}")
     else:
-        print("\n No comparison.html found. Run comparison first or copy from capture directory.")
+        print("\n No comparison data found. Run comparison first or copy from capture directory.")
 
     if args.open:
         webbrowser.open(str(current_dir / "viewer.html"))
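For reference, a minimal usage sketch of the new entry point outside the CLI, mirroring what `cmd_viewer` now does. The `training_output/current` path is an assumed example, not taken from this diff.

```python
# Hypothetical standalone usage of the consolidated helper; the output
# directory path below is illustrative only.
import webbrowser
from pathlib import Path

from openadapt_ml.training.trainer import generate_unified_viewer_from_output_dir

output_dir = Path("training_output/current")  # assumed layout
viewer_path = generate_unified_viewer_from_output_dir(output_dir)
if viewer_path:
    # Same behavior as `viewer --open`: open the regenerated viewer in a browser.
    webbrowser.open(str(viewer_path))
else:
    print(f"No comparison data found in {output_dir}")
```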